diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h --- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h +++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h @@ -27,8 +27,13 @@ /// Collect the patterns to convert from the ControlFlow dialect to LLVM. The /// conversion patterns capture the LLVMTypeConverter by reference meaning the /// references have to remain alive during the entire pattern lifetime. -void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); +/// +/// The message of a failed `cf.assert` is always printed. If +/// `abortOnFailedAssert` is true (the default), the program then aborts; +/// otherwise execution continues, which is useful for testing asserts. +void populateControlFlowToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns, + bool abortOnFailedAssert = true); /// Creates a pass to convert the ControlFlow dialect into the LLVMIR dialect.
std::unique_ptr createConvertControlFlowToLLVMPass(); diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -248,6 +248,9 @@ Option<"indexBitwidth", "index-bitwidth", "unsigned", /*default=kDeriveIndexBitwidthFromDataLayout*/"0", "Bitwidth of the index type, 0 to use size of machine word">, + Option<"abortOnFailedAssert", "abort-on-failed-assert", "bool", + /*default=*/"true", + "Abort program when an assertion is failing"> ]; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -36,6 +36,7 @@ LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp); diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h --- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h @@ -470,6 +470,11 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma(); extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline(); +//===----------------------------------------------------------------------===// +// Small runtime support library for debug assertions.
+//===----------------------------------------------------------------------===// +extern "C" MLIR_CRUNNERUTILS_EXPORT void printStr(int8_t *str); + //===----------------------------------------------------------------------===// // Small runtime support library for timing execution and printing GFLOPS //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -36,38 +36,88 @@ #define PASS_NAME "convert-cf-to-llvm" namespace { +std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) { + std::string prefix = "assert_msg_"; + int counter = 0; + while (moduleOp.lookupSymbol(prefix + std::to_string(counter))) + ++counter; + return prefix + std::to_string(counter); +} + +/// Generate IR that prints the given string, followed by a newline, to stderr. +void printMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp, + StringRef msg) { + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(moduleOp.getBody()); + MLIRContext *ctx = builder.getContext(); + + // Create a zero-terminated byte representation and allocate global symbol. + SmallVector elementVals; + elementVals.append(msg.begin(), msg.end()); + elementVals.push_back('\n'); + elementVals.push_back(0); + auto dataAttrType = RankedTensorType::get( + {static_cast(elementVals.size())}, builder.getI8Type()); + auto dataAttr = + DenseElementsAttr::get(dataAttrType, llvm::makeArrayRef(elementVals)); + auto arrayTy = + LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); + std::string symbolName = generateGlobalMsgSymbolName(moduleOp); + auto globalOp = builder.create( + loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName, + dataAttr); + + // Emit call to `printStr` in runtime library.
+ builder.restoreInsertionPoint(ip); + auto msgAddr = builder.create( + loc, LLVM::LLVMPointerType::get(arrayTy), globalOp.getName()); + SmallVector indices(1, 0); + Value gep = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getI8Type()), msgAddr, indices); + Operation *printer = LLVM::lookupOrCreatePrintStrFn(moduleOp); + builder.create(loc, TypeRange(), SymbolRefAttr::get(printer), + gep); +} + -/// Lower `cf.assert`. The default lowering calls the `abort` function if the -/// assertion is violated and has no effect otherwise. The failure message is -/// ignored by the default lowering but should be propagated by any custom -/// lowering. +/// Lower `cf.assert`. If the assertion is violated, the default lowering +/// prints the failure message and then calls the `abort` function; it has +/// no effect otherwise. If `abortOnFailedAssert` is false, execution +/// continues after the message has been printed. struct AssertOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + explicit AssertOpLowering(LLVMTypeConverter &typeConverter, + bool abortOnFailedAssert = true) + : ConvertOpToLLVMPattern(typeConverter, /*benefit=*/1), + abortOnFailedAssert(abortOnFailedAssert) {} LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - - // Insert the `abort` declaration if necessary. auto module = op->getParentOfType(); - auto abortFunc = module.lookupSymbol("abort"); - if (!abortFunc) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); - abortFunc = rewriter.create(rewriter.getUnknownLoc(), - "abort", abortFuncTy); - } // Split block at `assert` operation. Block *opBlock = rewriter.getInsertionBlock(); auto opPosition = rewriter.getInsertionPoint(); Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); - // Generate IR to call `abort`. + // Failed block: Print the message, then `abort` or continue execution.
+ Block *failureBlock = rewriter.createBlock(opBlock->getParent()); - rewriter.create(loc, abortFunc, llvm::None); - rewriter.create(loc); + printMsg(rewriter, loc, module, op.getMsg()); + if (abortOnFailedAssert) { + // Insert the `abort` declaration if necessary. + auto abortFunc = module.lookupSymbol("abort"); + if (!abortFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); + abortFunc = rewriter.create(rewriter.getUnknownLoc(), + "abort", abortFuncTy); + } + rewriter.create(loc, abortFunc, llvm::None); + rewriter.create(loc); + } else { + rewriter.create(loc, ValueRange(), continuationBlock); + } // Generate assertion test. rewriter.setInsertionPointToEnd(opBlock); @@ -76,6 +126,11 @@ return success(); } + +private: + /// If set to `false`, messages are printed but program execution continues. + /// This is useful for testing asserts. + bool abortOnFailedAssert = true; }; /// The cf->LLVM lowerings for branching ops require that the blocks they jump @@ -185,14 +240,15 @@ } // namespace void mlir::cf::populateControlFlowToLLVMConversionPatterns( - LLVMTypeConverter &converter, RewritePatternSet &patterns) { + LLVMTypeConverter &converter, RewritePatternSet &patterns, + bool abortOnFailedAssert) { // clang-format off patterns.add< - AssertOpLowering, BranchOpLowering, CondBranchOpLowering, SwitchOpLowering>(converter); // clang-format on + patterns.add(converter, abortOnFailedAssert); } //===----------------------------------------------------------------------===// @@ -215,7 +271,8 @@ options.overrideIndexBitwidth(indexBitwidth); LLVMTypeConverter converter(&getContext(), options); - mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); + mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns, + abortOnFailedAssert); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff
--git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -28,6 +28,7 @@ static constexpr llvm::StringRef kPrintU64 = "printU64"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; +static constexpr llvm::StringRef kPrintStr = "printStr"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; static constexpr llvm::StringRef kPrintComma = "printComma"; @@ -78,6 +79,13 @@ LLVM::LLVMVoidType::get(moduleOp->getContext())); } +LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStrFn(ModuleOp moduleOp) { + return lookupOrCreateFn( + moduleOp, kPrintStr, + LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)), + LLVM::LLVMVoidType::get(moduleOp->getContext())); +} + LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) { return lookupOrCreateFn(moduleOp, kPrintOpen, {}, LLVM::LLVMVoidType::get(moduleOp->getContext())); diff --git a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp --- a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp +++ b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp @@ -46,6 +46,11 @@ extern "C" void printComma() { fputs(", ", stdout); } extern "C" void printNewline() { fputc('\n', stdout); } +// For debug assertions only: Print the given null-terminated string to stderr.
+extern "C" void printStr(int8_t *str) { + fputs(reinterpret_cast(str), stderr); +} + extern "C" void memrefCopy(int64_t elemSize, UnrankedMemRefType *srcArg, UnrankedMemRefType *dstArg) { DynamicMemRefType src(*srcArg); diff --git a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -convert-cf-to-llvm="abort-on-failed-assert=0" \ +// RUN: -convert-func-to-llvm | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext 2>&1 | \ +// RUN: FileCheck %s + +func.func @main() { + %a = arith.constant 0 : i1 + %b = arith.constant 1 : i1 + // CHECK: assertion foo + cf.assert %a, "assertion foo" + // CHECK-NOT: assertion bar + cf.assert %b, "assertion bar" + return +}