diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1422,6 +1422,50 @@
 OneToOneConvertToLLVMPattern;
 using XOrOpLowering = VectorConvertToLLVMPattern;
 
+/// Lower `std.assert`. The default lowering calls the `abort` function if the
+/// assertion is violated and has no effect otherwise. The failure message is
+/// ignored by the default lowering but should be propagated by any custom
+/// lowering.
+struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
+  using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    AssertOp::Adaptor transformed(operands);
+
+    // Declare `abort` at the module start unless it is already declared.
+    auto module = op->getParentOfType<ModuleOp>();
+    auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
+    if (!abortFunc) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(module.getBody());
+      auto abortFuncTy =
+          LLVM::LLVMType::getFunctionTy(getVoidType(), {}, /*isVarArg=*/false);
+      abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
+                                                    "abort", abortFuncTy);
+    }
+
+    // Split the block at the `assert` op itself rather than at the rewriter's
+    // insertion point. The assert and all later ops move to the new block.
+    Block *opBlock = op->getBlock();
+    Block *continuationBlock = rewriter.splitBlock(opBlock, op->getIterator());
+
+    // Append a failure block that calls `abort` and ends in `unreachable`.
+    Block *failureBlock = rewriter.createBlock(opBlock->getParent());
+    rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
+    rewriter.create<LLVM::UnreachableOp>(loc);
+
+    // Branch on the converted condition; this also erases the `assert` op.
+    rewriter.setInsertionPointToEnd(opBlock);
+    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
+        op, transformed.arg(), continuationBlock, failureBlock);
+
+    return success();
+  }
+};
+
 // Lowerings for operations on complex numbers.
struct CreateComplexOpLowering @@ -3169,6 +3213,7 @@ AddIOpLowering, AllocaOpLowering, AndOpLowering, + AssertOpLowering, AtomicRMWOpLowering, BranchOpLowering, CallIndirectOpLowering, diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -75,3 +75,21 @@ %0 = rsqrt %arg0 : vector<4x3xf32> std.return } + +// ----- + +// Lowers `assert` to a function call to `abort` if the assertion is violated. +// CHECK: llvm.func @abort() +// CHECK-LABEL: @assert_test_function +// CHECK-SAME: (%[[ARG:.*]]: !llvm.i1) +func @assert_test_function(%arg : i1) { + // CHECK: llvm.cond_br %[[ARG]], ^[[CONTINUATION_BLOCK:.*]], ^[[FAILURE_BLOCK:.*]] + // CHECK: ^[[CONTINUATION_BLOCK]]: + // CHECK: llvm.return + // CHECK: ^[[FAILURE_BLOCK]]: + // CHECK: llvm.call @abort() : () -> () + // CHECK: llvm.unreachable + assert %arg, "Computer says no" + return +} +