diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1399,6 +1399,56 @@
     OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
 using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
 
+/// Lower `std.assert`. The default lowering calls the `abort` function if the
+/// assertion is violated and has no effect otherwise. The failure message is
+/// ignored by the default lowering but should be propagated by any custom
+/// lowering.
+///
+/// The block holding the `assert` is split at the op: the original block
+/// ends in a conditional branch that falls through to the remaining ops when
+/// the condition holds and jumps to a new block calling `abort` otherwise.
+struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
+  using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    AssertOp::Adaptor transformed(operands);
+
+    // Insert the `abort` declaration if necessary.
+    auto module = op->getParentOfType<ModuleOp>();
+    auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
+    if (!abortFunc) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(module.getBody());
+      auto abortFuncTy =
+          LLVM::LLVMType::getFunctionTy(getVoidType(), {}, /*isVarArg=*/false);
+      abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
+                                                    "abort", abortFuncTy);
+    }
+
+    // Split the block at the `assert` operation; the `assert` and all ops
+    // after it move into the continuation block.
+    auto *opBlock = rewriter.getInsertionBlock();
+    auto opPosition = rewriter.getInsertionPoint();
+    auto *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
+
+    // Generate IR to call `abort`.
+    auto *failureBlock = rewriter.createBlock(opBlock->getParent());
+    rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
+    rewriter.create<LLVM::UnreachableOp>(loc);
+
+    // Generate assertion test: the true successor must be the continuation
+    // block (assertion holds), the false successor the `abort` block.
+    rewriter.setInsertionPointToEnd(opBlock);
+    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
+        op, transformed.arg(), continuationBlock, failureBlock);
+
+    return success();
+  }
+};
+
 // Lowerings for operations on complex numbers.
 
 struct CreateComplexOpLowering
@@ -3146,6 +3196,7 @@
       AddIOpLowering,
       AllocaOpLowering,
       AndOpLowering,
+      AssertOpLowering,
       AtomicRMWOpLowering,
       BranchOpLowering,
       CallIndirectOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -87,3 +87,20 @@
   // expected-error@+1 {{must be LLVM dialect type}}
   return %1 : i32
 }
+
+// -----
+
+// Lowers `assert` to a function call to `abort` if the assertion is violated.
+// CHECK: llvm.func @abort()
+// CHECK-LABEL: @assert_test_function
+// CHECK-SAME: (%[[ARG:.*]]: !llvm.i1)
+func @assert_test_function(%arg : i1) {
+  // CHECK: llvm.cond_br %[[ARG]], ^[[CONTINUATION_BLOCK:.*]], ^[[FAILURE_BLOCK:.*]]
+  // CHECK: ^[[CONTINUATION_BLOCK]]:
+  // CHECK: llvm.return
+  // CHECK: ^[[FAILURE_BLOCK]]:
+  // CHECK: llvm.call @abort() : () -> ()
+  // CHECK: llvm.unreachable
+  assert %arg, "Computer says no"
+  return
+}