Skip to content

Commit

Permalink
[CombToSMT] Make result of div-by-zero undefined (#7025)
Browse files Browse the repository at this point in the history
This adapts the conversion pass to match the recently agreed upon definition for division by zero. Integration tests for circt-lec are added to check the behavior. Note that two syntactically equivalent modules are not considered equivalent if they aren't guaranteed to deterministically produce the same outputs. Alternatively, we could consider two undefined output values equivalent by modeling each value as a pair of a boolean and the bit-vector where the boolean determines if the value is undefined, then two outputs are equivalent if either the boolean is true or the boolean is false and the bitvectors match. There are probably use-cases for both, so maybe we'd want a flag to let the user decide.
  • Loading branch information
maerhart authored May 14, 2024
1 parent b51a644 commit 481cb60
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 39 deletions.
72 changes: 68 additions & 4 deletions integration_test/circt-lec/comb.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,42 @@ hw.module @decomposedAnd(in %in1: i1, in %in2: i1, out out: i1) {
// TODO

// comb.divs
// TODO
// RUN: circt-lec %s -c1=divs_unsafe -c2=divs_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVS_UNSAFE
// RUN: circt-lec %s -c1=divs -c2=divs --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVS
// COMB_DIVS_UNSAFE: c1 != c2
// COMB_DIVS: c1 == c2

hw.module @divs_unsafe(in %in1: i32, in %in2: i32, out out: i32) {
%0 = comb.divs %in1, %in2 : i32
hw.output %0 : i32
}

hw.module @divs(in %in1: i32, in %in2: i32, out out: i32) {
%0 = hw.constant 0 : i32
%1 = comb.icmp eq %in2, %0 : i32
%2 = comb.divs %in1, %in2 : i32
%3 = comb.mux %1, %0, %2 : i32
hw.output %3 : i32
}

// comb.divu
// TODO
// RUN: circt-lec %s -c1=divu_unsafe -c2=divu_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVU_UNSAFE
// RUN: circt-lec %s -c1=divu -c2=divu --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVU
// COMB_DIVU_UNSAFE: c1 != c2
// COMB_DIVU: c1 == c2

hw.module @divu_unsafe(in %in1: i32, in %in2: i32, out out: i32) {
%0 = comb.divu %in1, %in2 : i32
hw.output %0 : i32
}

hw.module @divu(in %in1: i32, in %in2: i32, out out: i32) {
%0 = hw.constant 0 : i32
%1 = comb.icmp eq %in2, %0 : i32
%2 = comb.divu %in1, %in2 : i32
%3 = comb.mux %1, %0, %2 : i32
hw.output %3 : i32
}

// comb.extract
// TODO
Expand All @@ -85,10 +117,42 @@ hw.module @constFalse(in %a: i8, out eq: i1) {
// TODO: Other icmp predicates

// comb.mods
// TODO
// RUN: circt-lec %s -c1=mods_unsafe -c2=mods_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODS_UNSAFE
// RUN: circt-lec %s -c1=mods -c2=mods --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODS
// COMB_MODS_UNSAFE: c1 != c2
// COMB_MODS: c1 == c2

hw.module @mods_unsafe(in %in1: i32, in %in2: i32, out out: i32) {
%0 = comb.mods %in1, %in2 : i32
hw.output %0 : i32
}

hw.module @mods(in %in1: i32, in %in2: i32, out out: i32) {
%0 = hw.constant 0 : i32
%1 = comb.icmp eq %in2, %0 : i32
%2 = comb.mods %in1, %in2 : i32
%3 = comb.mux %1, %0, %2 : i32
hw.output %3 : i32
}

// comb.modu
// TODO
// RUN: circt-lec %s -c1=modu_unsafe -c2=modu_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODU_UNSAFE
// RUN: circt-lec %s -c1=modu -c2=modu --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODU
// COMB_MODU_UNSAFE: c1 != c2
// COMB_MODU: c1 == c2

hw.module @modu_unsafe(in %in1: i32, in %in2: i32, out out: i32) {
%0 = comb.modu %in1, %in2 : i32
hw.output %0 : i32
}

hw.module @modu(in %in1: i32, in %in2: i32, out out: i32) {
%0 = hw.constant 0 : i32
%1 = comb.icmp eq %in2, %0 : i32
%2 = comb.modu %in1, %in2 : i32
%3 = comb.mux %1, %0, %2 : i32
hw.output %3 : i32
}

// comb.mul
// RUN: circt-lec %s -c1=mulBy2 -c2=addTwice --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MUL
Expand Down
36 changes: 32 additions & 4 deletions lib/Conversion/CombToSMT/CombToSMT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,34 @@ struct OneToOneOpConversion : OpConversionPattern<SourceOp> {
}
};

/// Lower the SourceOp to the TargetOp special-casing if the second operand is
/// zero to return a new symbolic value.
template <typename SourceOp, typename TargetOp>
struct DivisionOpConversion : OpConversionPattern<SourceOp> {
using OpConversionPattern<SourceOp>::OpConversionPattern;
using OpAdaptor = typename SourceOp::Adaptor;

LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto type = dyn_cast<smt::BitVectorType>(adaptor.getRhs().getType());
if (!type)
return failure();

auto resultType = OpConversionPattern<SourceOp>::typeConverter->convertType(
op.getResult().getType());
Value zero =
rewriter.create<smt::BVConstantOp>(loc, APInt(type.getWidth(), 0));
Value isZero = rewriter.create<smt::EqOp>(loc, adaptor.getRhs(), zero);
Value symbolicVal = rewriter.create<smt::DeclareFunOp>(loc, resultType);
Value division =
rewriter.create<TargetOp>(loc, resultType, adaptor.getOperands());
rewriter.replaceOpWithNewOp<smt::IteOp>(op, isZero, symbolicVal, division);
return success();
}
};

/// Converts an operation with a variadic number of operands to a chain of
/// binary operations assuming left-associativity of the operation.
template <typename SourceOp, typename TargetOp>
Expand Down Expand Up @@ -236,10 +264,10 @@ void circt::populateCombToSMTConversionPatterns(TypeConverter &converter,
OneToOneOpConversion<ShlOp, smt::BVShlOp>,
OneToOneOpConversion<ShrUOp, smt::BVLShrOp>,
OneToOneOpConversion<ShrSOp, smt::BVAShrOp>,
OneToOneOpConversion<DivSOp, smt::BVSDivOp>,
OneToOneOpConversion<DivUOp, smt::BVUDivOp>,
OneToOneOpConversion<ModSOp, smt::BVSRemOp>,
OneToOneOpConversion<ModUOp, smt::BVURemOp>,
DivisionOpConversion<DivSOp, smt::BVSDivOp>,
DivisionOpConversion<DivUOp, smt::BVUDivOp>,
DivisionOpConversion<ModSOp, smt::BVSRemOp>,
DivisionOpConversion<ModUOp, smt::BVURemOp>,
VariadicToBinaryOpConversion<ConcatOp, smt::ConcatOp>,
VariadicToBinaryOpConversion<AddOp, smt::BVAddOp>,
VariadicToBinaryOpConversion<MulOp, smt::BVMulOp>,
Expand Down
49 changes: 22 additions & 27 deletions lib/Tools/circt-lec/ConstructLEC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,37 +110,32 @@ void ConstructLECPass::runOnOperation() {

builder.createBlock(&entryFunc.getBody());

Value areEquivalent;
if (moduleA == moduleB) {
// Trivially equivalent
areEquivalent =
builder.create<LLVM::ConstantOp>(loc, builder.getI1Type(), 1);
moduleA->erase();
} else {
auto lecOp = builder.create<verif::LogicEquivalenceCheckingOp>(loc);
areEquivalent = lecOp.getAreEquivalent();
auto *outputOpA = moduleA.getBodyBlock()->getTerminator();
auto *outputOpB = moduleB.getBodyBlock()->getTerminator();
lecOp.getFirstCircuit().takeBody(moduleA.getBody());
lecOp.getSecondCircuit().takeBody(moduleB.getBody());

moduleA->erase();
auto lecOp = builder.create<verif::LogicEquivalenceCheckingOp>(loc);
Value areEquivalent = lecOp.getAreEquivalent();
builder.cloneRegionBefore(moduleA.getBody(), lecOp.getFirstCircuit(),
lecOp.getFirstCircuit().end());
builder.cloneRegionBefore(moduleB.getBody(), lecOp.getSecondCircuit(),
lecOp.getSecondCircuit().end());

moduleA->erase();
if (moduleA != moduleB)
moduleB->erase();

{
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(outputOpA);
builder.create<verif::YieldOp>(loc, outputOpA->getOperands());
outputOpA->erase();
builder.setInsertionPoint(outputOpB);
builder.create<verif::YieldOp>(loc, outputOpB->getOperands());
outputOpB->erase();
}

sortTopologically(&lecOp.getFirstCircuit().front());
sortTopologically(&lecOp.getSecondCircuit().front());
{
auto *term = lecOp.getFirstCircuit().front().getTerminator();
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(term);
builder.create<verif::YieldOp>(loc, term->getOperands());
term->erase();
term = lecOp.getSecondCircuit().front().getTerminator();
builder.setInsertionPoint(term);
builder.create<verif::YieldOp>(loc, term->getOperands());
term->erase();
}

sortTopologically(&lecOp.getFirstCircuit().front());
sortTopologically(&lecOp.getSecondCircuit().front());

// TODO: we should find a more elegant way of reporting the result than
// already inserting some LLVM here
Value eqFormatString =
Expand Down
24 changes: 20 additions & 4 deletions test/Conversion/CombToSMT/comb-to-smt.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,29 @@ func.func @test(%a0: !smt.bv<32>, %a1: !smt.bv<32>, %a2: !smt.bv<32>, %a3: !smt.
%arg4 = builtin.unrealized_conversion_cast %a4 : !smt.bv<1> to i1
%arg5 = builtin.unrealized_conversion_cast %a5 : !smt.bv<4> to i4

// CHECK: smt.bv.sdiv [[A0]], [[A1]] : !smt.bv<32>
// CHECK: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
// CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32>
// CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32>
// CHECK-NEXT: [[DIV:%.+]] = smt.bv.sdiv [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32>
%0 = comb.divs %arg0, %arg1 : i32
// CHECK-NEXT: smt.bv.udiv [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
// CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32>
// CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32>
// CHECK-NEXT: [[DIV:%.+]] = smt.bv.udiv [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32>
%1 = comb.divu %arg0, %arg1 : i32
// CHECK-NEXT: smt.bv.srem [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
// CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32>
// CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32>
// CHECK-NEXT: [[DIV:%.+]] = smt.bv.srem [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32>
%2 = comb.mods %arg0, %arg1 : i32
// CHECK-NEXT: smt.bv.urem [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
// CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32>
// CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32>
// CHECK-NEXT: [[DIV:%.+]] = smt.bv.urem [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32>
%3 = comb.modu %arg0, %arg1 : i32

// CHECK-NEXT: [[NEG:%.+]] = smt.bv.neg [[A1]] : !smt.bv<32>
Expand Down

0 comments on commit 481cb60

Please sign in to comment.