From dc85cb000e6aa4f137d98337149025b6a2f88ec8 Mon Sep 17 00:00:00 2001 From: Daniel Kroening Date: Thu, 5 Sep 2024 10:17:50 +0100 Subject: [PATCH] introduce zero_extend expression This introduces the zero_extend expression, which, given a bit-vector operand and a type, either a) pads the given operand with zeros from the left if the given type is wider than the type of the operand, or b) truncates the operand to the width of the given type if the given type is smaller than the operand, or c) reinterprets the operand as having the given type if the width of the type and the width of the operand match. This may differ from conversion if the types have different bit representations. This is easier to read and less prone to error than the current pattern, in which the operand is 1) converted to an unsigned type of the same width, and then 2) casted to an unsigned type of the wider width, and 3) finally casted to the target type. --- src/solvers/flattening/boolbv.cpp | 2 + src/solvers/floatbv/float_bv.cpp | 8 ++-- src/solvers/smt2/smt2_conv.cpp | 4 ++ .../smt2_incremental/convert_expr_to_smt.cpp | 13 ++++++ .../smt2_incremental_decision_procedure.cpp | 18 +++++++- src/util/bitvector_expr.cpp | 21 +++++++-- src/util/bitvector_expr.h | 44 +++++++++++++++++++ src/util/format_expr.cpp | 6 +++ src/util/irep_ids.def | 1 + src/util/lower_byte_operators.cpp | 19 ++++---- src/util/simplify_expr.cpp | 4 ++ src/util/simplify_expr_class.h | 2 + src/util/simplify_expr_int.cpp | 12 +++++ 13 files changed, 137 insertions(+), 17 deletions(-) diff --git a/src/solvers/flattening/boolbv.cpp b/src/solvers/flattening/boolbv.cpp index ad155246fad2..f1f1f7c9de70 100644 --- a/src/solvers/flattening/boolbv.cpp +++ b/src/solvers/flattening/boolbv.cpp @@ -165,6 +165,8 @@ bvt boolbvt::convert_bitvector(const exprt &expr) return convert_replication(to_replication_expr(expr)); else if(expr.id()==ID_extractbits) return convert_extractbits(to_extractbits_expr(expr)); + else if(expr.id() == ID_zero_extend) + return convert_bitvector(to_zero_extend_expr(expr).lower()); else if(expr.id()==ID_bitnot || expr.id()==ID_bitand || expr.id()==ID_bitor || expr.id()==ID_bitxor || expr.id()==ID_bitxnor || expr.id()==ID_bitnor || diff --git a/src/solvers/floatbv/float_bv.cpp b/src/solvers/floatbv/float_bv.cpp index 12e87f923bff..162f1e8cd0a5 100644 --- a/src/solvers/floatbv/float_bv.cpp +++ b/src/solvers/floatbv/float_bv.cpp @@ -692,8 +692,10 @@ exprt float_bvt::mul( // zero-extend the fractions (unpacked fraction has the hidden bit) typet new_fraction_type=unsignedbv_typet((spec.f+1)*2); - const exprt fraction1=typecast_exprt(unpacked1.fraction, new_fraction_type); - const exprt fraction2=typecast_exprt(unpacked2.fraction, new_fraction_type); + const exprt fraction1 = + zero_extend_exprt{unpacked1.fraction, new_fraction_type}; + const exprt fraction2 = + zero_extend_exprt{unpacked2.fraction, new_fraction_type}; // multiply the fractions unbiased_floatt result; @@ -750,7 +752,7 @@ exprt float_bvt::div( unsignedbv_typet(div_width)); // zero-extend fraction2 to match fraction1 - const typecast_exprt fraction2(unpacked2.fraction, fraction1.type()); + const zero_extend_exprt fraction2{unpacked2.fraction, fraction1.type()}; // divide fractions unbiased_floatt result; diff --git a/src/solvers/smt2/smt2_conv.cpp b/src/solvers/smt2/smt2_conv.cpp index fcbb43bf99a5..2402bb8ca920 100644 --- a/src/solvers/smt2/smt2_conv.cpp +++ b/src/solvers/smt2/smt2_conv.cpp @@ -2456,6 +2456,10 @@ void smt2_convt::convert_expr(const exprt &expr) { convert_expr(simplify_expr(to_bitreverse_expr(expr).lower(), ns)); } + else if(expr.id() == ID_zero_extend) + { + convert_expr(to_zero_extend_expr(expr).lower()); + } else if(expr.id() == ID_function_application) { const auto &function_application_expr = to_function_application_expr(expr); diff --git a/src/solvers/smt2_incremental/convert_expr_to_smt.cpp b/src/solvers/smt2_incremental/convert_expr_to_smt.cpp index 614a3659319b..3632147c0a8f 100644 --- a/src/solvers/smt2_incremental/convert_expr_to_smt.cpp +++ b/src/solvers/smt2_incremental/convert_expr_to_smt.cpp @@ -1469,6 +1469,15 @@ static smt_termt convert_expr_to_smt( count_trailing_zeros.pretty()); } +static smt_termt convert_expr_to_smt( + const zero_extend_exprt &zero_extend, + const sub_expression_mapt &converted) +{ + UNREACHABLE_BECAUSE( + "zero_extend expression should have been lowered by the decision " + "procedure before conversion to smt terms"); +} + static smt_termt convert_expr_to_smt( const prophecy_r_or_w_ok_exprt &prophecy_r_or_w_ok, const sub_expression_mapt &converted) @@ -1822,6 +1831,10 @@ static smt_termt dispatch_expr_to_smt_conversion( { return convert_expr_to_smt(*count_trailing_zeros, converted); } + if(const auto zero_extend = expr_try_dynamic_cast(expr)) + { + return convert_expr_to_smt(*zero_extend, converted); + } if( const auto prophecy_r_or_w_ok = expr_try_dynamic_cast(expr)) diff --git a/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp b/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp index bc78dfc171d5..72575d89f6b8 100644 --- a/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp +++ b/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp @@ -3,6 +3,7 @@ #include "smt2_incremental_decision_procedure.h" #include +#include #include #include #include @@ -296,6 +297,17 @@ static exprt lower_rw_ok_pointer_in_range(exprt expr, const namespacet &ns) return expr; } +static exprt lower_zero_extend(exprt expr, const namespacet &ns) +{ + expr.visit_pre([](exprt &expr) { + if(auto zero_extend = expr_try_dynamic_cast(expr)) + { + expr = zero_extend->lower(); + } + }); + return expr; +} + void smt2_incremental_decision_proceduret::ensure_handle_for_expr_defined( const exprt &in_expr) { @@ -677,8 +689,10 @@ void smt2_incremental_decision_proceduret::define_object_properties() exprt smt2_incremental_decision_proceduret::lower(exprt expression) const { - const exprt lowered = struct_encoding.encode(lower_enum( - lower_byte_operators(lower_rw_ok_pointer_in_range(expression, ns), ns), + const exprt lowered = struct_encoding.encode(lower_zero_extend( + lower_enum( + lower_byte_operators(lower_rw_ok_pointer_in_range(expression, ns), ns), + ns), ns)); log.conditional_output(log.debug(), [&](messaget::mstreamt &debug) { if(lowered != expression) diff --git a/src/util/bitvector_expr.cpp b/src/util/bitvector_expr.cpp index 940fa07a0b14..ac766d8ebee4 100644 --- a/src/util/bitvector_expr.cpp +++ b/src/util/bitvector_expr.cpp @@ -54,8 +54,7 @@ exprt update_bit_exprt::lower() const typecast_exprt(src(), src_bv_type), bitnot_exprt(mask_shifted)); // zero-extend the replacement bit to match src - auto new_value_casted = typecast_exprt( - typecast_exprt(new_value(), unsignedbv_typet(width)), src_bv_type); + auto new_value_casted = zero_extend_exprt{new_value(), src_bv_type}; // shift the replacement bits auto new_value_shifted = shl_exprt(new_value_casted, index()); @@ -85,7 +84,7 @@ exprt update_bits_exprt::lower() const bitand_exprt(typecast_exprt(src(), src_bv_type), mask_shifted); // zero-extend or shrink the replacement bits to match src - auto new_value_casted = typecast_exprt(new_value(), src_bv_type); + auto new_value_casted = zero_extend_exprt{new_value(), src_bv_type}; // shift the replacement bits auto new_value_shifted = shl_exprt(new_value_casted, index()); @@ -279,3 +278,19 @@ exprt find_first_set_exprt::lower() const return typecast_exprt::conditional_cast(result, type()); } + +exprt zero_extend_exprt::lower() const +{ + const auto old_width = to_bitvector_type(op().type()).get_width(); + const auto new_width = to_bitvector_type(type()).get_width(); + + if(new_width > old_width) + { + return concatenation_exprt{ + bv_typet{new_width - old_width}.all_zeros_expr(), op(), type()}; + } + else // new_width <= old_width + { + return extractbits_exprt{op(), integer_typet{}.zero_expr(), type()}; + } +} diff --git a/src/util/bitvector_expr.h b/src/util/bitvector_expr.h index 55a100c9bb5b..cf40a5af7643 100644 --- a/src/util/bitvector_expr.h +++ b/src/util/bitvector_expr.h @@ -1663,4 +1663,48 @@ inline find_first_set_exprt &to_find_first_set_expr(exprt &expr) return ret; } +/// \brief zero extension +/// The operand is converted to the given type by either +/// a) truncating if the new type is shorter, or +/// b) padding with most-significant zero bits if the new type is larger, or +/// c) reinterprets the operand as the given type if their widths match. +class zero_extend_exprt : public unary_exprt +{ +public: + zero_extend_exprt(exprt _op, typet _type) + : unary_exprt(ID_zero_extend, std::move(_op), std::move(_type)) + { + } + + // a lowering to extraction or concatenation + exprt lower() const; +}; + +template <> +inline bool can_cast_expr(const exprt &base) +{ + return base.id() == ID_zero_extend; +} + +/// \brief Cast an exprt to a \ref zero_extend_exprt +/// +/// \a expr must be known to be \ref zero_extend_exprt. +/// +/// \param expr: Source expression +/// \return Object of type \ref zero_extend_exprt +inline const zero_extend_exprt &to_zero_extend_expr(const exprt &expr) +{ + PRECONDITION(expr.id() == ID_zero_extend); + zero_extend_exprt::check(expr); + return static_cast(expr); +} + +/// \copydoc to_zero_extend_expr(const exprt &) +inline zero_extend_exprt &to_zero_extend_expr(exprt &expr) +{ + PRECONDITION(expr.id() == ID_zero_extend); + zero_extend_exprt::check(expr); + return static_cast(expr); +} + #endif // CPROVER_UTIL_BITVECTOR_EXPR_H diff --git a/src/util/format_expr.cpp b/src/util/format_expr.cpp index 436fc0540467..08ed7900d38b 100644 --- a/src/util/format_expr.cpp +++ b/src/util/format_expr.cpp @@ -376,6 +376,12 @@ void format_expr_configt::setup() << format(expr.type()) << ')'; }; + expr_map[ID_zero_extend] = + [](std::ostream &os, const exprt &expr) -> std::ostream & { + return os << "zero_extend(" << format(to_zero_extend_expr(expr).op()) + << ", " << format(expr.type()) << ')'; + }; + expr_map[ID_floatbv_typecast] = [](std::ostream &os, const exprt &expr) -> std::ostream & { const auto &floatbv_typecast_expr = to_floatbv_typecast_expr(expr); diff --git a/src/util/irep_ids.def b/src/util/irep_ids.def index f1728411191f..2582e750cd51 100644 --- a/src/util/irep_ids.def +++ b/src/util/irep_ids.def @@ -188,6 +188,7 @@ IREP_ID_ONE(extractbit) IREP_ID_ONE(extractbits) IREP_ID_ONE(update_bit) IREP_ID_ONE(update_bits) +IREP_ID_ONE(zero_extend) IREP_ID_TWO(C_reference, #reference) IREP_ID_TWO(C_rvalue_reference, #rvalue_reference) IREP_ID_ONE(true) diff --git a/src/util/lower_byte_operators.cpp b/src/util/lower_byte_operators.cpp index 701214d19362..2399796f6952 100644 --- a/src/util/lower_byte_operators.cpp +++ b/src/util/lower_byte_operators.cpp @@ -2491,15 +2491,16 @@ static exprt lower_byte_update( exprt zero_extended; if(bit_width > update_size_bits) { - zero_extended = concatenation_exprt{ - bv_typet{bit_width - update_size_bits}.all_zeros_expr(), - value, - bv_typet{bit_width}}; - - if(!is_little_endian) - to_concatenation_expr(zero_extended) - .op0() - .swap(to_concatenation_expr(zero_extended).op1()); + if(is_little_endian) + zero_extended = zero_extend_exprt{value, bv_typet{bit_width}}; + else + { + // Big endian -- the zero is added as LSB. + zero_extended = concatenation_exprt{ + value, + bv_typet{bit_width - update_size_bits}.all_zeros_expr(), + bv_typet{bit_width}}; + } } else zero_extended = value; diff --git a/src/util/simplify_expr.cpp b/src/util/simplify_expr.cpp index af6f3c55186d..f29fd8163ae0 100644 --- a/src/util/simplify_expr.cpp +++ b/src/util/simplify_expr.cpp @@ -3028,6 +3028,10 @@ simplify_exprt::resultt<> simplify_exprt::simplify_node(const exprt &node) { r = simplify_extractbits(to_extractbits_expr(expr)); } + else if(expr.id() == ID_zero_extend) + { + r = simplify_zero_extend(to_zero_extend_expr(expr)); + } else if(expr.id()==ID_ieee_float_equal || expr.id()==ID_ieee_float_notequal) { diff --git a/src/util/simplify_expr_class.h b/src/util/simplify_expr_class.h index b9b2181d678c..78c1fc4e71c5 100644 --- a/src/util/simplify_expr_class.h +++ b/src/util/simplify_expr_class.h @@ -76,6 +76,7 @@ class unary_overflow_exprt; class unary_plus_exprt; class update_exprt; class with_exprt; +class zero_extend_exprt; class simplify_exprt { @@ -152,6 +153,7 @@ class simplify_exprt [[nodiscard]] resultt<> simplify_extractbit(const extractbit_exprt &); [[nodiscard]] resultt<> simplify_extractbits(const extractbits_exprt &); [[nodiscard]] resultt<> simplify_concatenation(const concatenation_exprt &); + [[nodiscard]] resultt<> simplify_zero_extend(const zero_extend_exprt &); [[nodiscard]] resultt<> simplify_mult(const mult_exprt &); [[nodiscard]] resultt<> simplify_div(const div_exprt &); [[nodiscard]] resultt<> simplify_mod(const mod_exprt &); diff --git a/src/util/simplify_expr_int.cpp b/src/util/simplify_expr_int.cpp index 2087564a387a..e081a2ed0b4a 100644 --- a/src/util/simplify_expr_int.cpp +++ b/src/util/simplify_expr_int.cpp @@ -997,6 +997,18 @@ simplify_exprt::simplify_concatenation(const concatenation_exprt &expr) return std::move(new_expr); } +simplify_exprt::resultt<> +simplify_exprt::simplify_zero_extend(const zero_extend_exprt &expr) +{ + if(!can_cast_type(expr.type())) + return unchanged(expr); + + if(!can_cast_type(expr.op().type())) + return unchanged(expr); + + return changed(simplify_node(expr.lower())); +} + simplify_exprt::resultt<> simplify_exprt::simplify_shifts(const shift_exprt &expr) {