diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h --- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h +++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h @@ -30,6 +30,12 @@ void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns); +/// Populate the cf.assert to LLVM conversion pattern. If `abortOnFailure` is +/// `false`, the message is still printed but program execution continues. +void populateAssertToLLVMConversionPattern(LLVMTypeConverter &converter, + RewritePatternSet &patterns, + bool abortOnFailure = true); + /// Creates a pass to convert the ControlFlow dialect into the LLVMIR dialect. std::unique_ptr createConvertControlFlowToLLVMPass(); } // namespace cf diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -36,6 +36,7 @@ LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp); diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -35,39 +35,89 @@ #define PASS_NAME "convert-cf-to-llvm" +static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) { + std::string prefix = "assert_msg_"; + 
int counter = 0; + while (moduleOp.lookupSymbol(prefix + std::to_string(counter))) + ++counter; + return prefix + std::to_string(counter); +} + +/// Generate IR that prints `msg` (stored as a private global) to stdout via `puts`. +static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp, + StringRef msg) { + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(moduleOp.getBody()); + MLIRContext *ctx = builder.getContext(); + + // Create a zero-terminated byte representation and allocate global symbol. + SmallVector elementVals; + elementVals.append(msg.begin(), msg.end()); + // No trailing '\n' is stored: `puts` appends a newline itself. + elementVals.push_back(0); + auto dataAttrType = RankedTensorType::get( + {static_cast(elementVals.size())}, builder.getI8Type()); + auto dataAttr = + DenseElementsAttr::get(dataAttrType, llvm::makeArrayRef(elementVals)); + auto arrayTy = + LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); + std::string symbolName = generateGlobalMsgSymbolName(moduleOp); + auto globalOp = builder.create( + loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName, + dataAttr); + + // Emit a call to the string printer (`puts`) at the saved insertion point. + builder.restoreInsertionPoint(ip); + auto msgAddr = builder.create( + loc, LLVM::LLVMPointerType::get(arrayTy), globalOp.getName()); + SmallVector indices(1, 0); + Value gep = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getI8Type()), msgAddr, indices); + Operation *printer = LLVM::lookupOrCreatePrintStrFn(moduleOp); + builder.create(loc, TypeRange(), SymbolRefAttr::get(printer), + gep); +} + namespace { -/// Lower `cf.assert`. The default lowering calls the `abort` function if the -/// assertion is violated and has no effect otherwise. The failure message is -/// ignored by the default lowering but should be propagated by any custom -/// lowering. +/// Lower `cf.assert`. If the assertion is violated, the lowering prints the +/// failure message and then, unless `abortOnFailedAssert` is `false`, calls +/// the `abort` function. A satisfied assertion has no effect. The message is +/// emitted as a private global string in the enclosing module. 
struct AssertOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + explicit AssertOpLowering(LLVMTypeConverter &typeConverter, + bool abortOnFailedAssert = true) + : ConvertOpToLLVMPattern(typeConverter, /*benefit=*/1), + abortOnFailedAssert(abortOnFailedAssert) {} LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - - // Insert the `abort` declaration if necessary. auto module = op->getParentOfType(); - auto abortFunc = module.lookupSymbol("abort"); - if (!abortFunc) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); - abortFunc = rewriter.create(rewriter.getUnknownLoc(), - "abort", abortFuncTy); - } // Split block at `assert` operation. Block *opBlock = rewriter.getInsertionBlock(); auto opPosition = rewriter.getInsertionPoint(); Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); - // Generate IR to call `abort`. + // Failure block: print the message, then either abort or resume execution. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); - rewriter.create(loc, abortFunc, llvm::None); - rewriter.create(loc); + createPrintMsg(rewriter, loc, module, op.getMsg()); + if (abortOnFailedAssert) { + // Insert the `abort` declaration if necessary. + auto abortFunc = module.lookupSymbol("abort"); + if (!abortFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); + abortFunc = rewriter.create(rewriter.getUnknownLoc(), + "abort", abortFuncTy); + } + rewriter.create(loc, abortFunc, llvm::None); + rewriter.create(loc); + } else { + rewriter.create(loc, ValueRange(), continuationBlock); + } // Generate assertion test. 
rewriter.setInsertionPointToEnd(opBlock); @@ -76,6 +126,11 @@ return success(); } + +private: + /// If `false`, the failure message is still printed but execution continues + /// instead of calling `abort`. This is useful for testing asserts. + bool abortOnFailedAssert = true; }; /// The cf->LLVM lowerings for branching ops require that the blocks they jump @@ -195,6 +250,12 @@ // clang-format on } +void mlir::cf::populateAssertToLLVMConversionPattern( + LLVMTypeConverter &converter, RewritePatternSet &patterns, + bool abortOnFailure) { + patterns.add(converter, abortOnFailure); +} + //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -28,6 +28,7 @@ static constexpr llvm::StringRef kPrintU64 = "printU64"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; +static constexpr llvm::StringRef kPrintStr = "puts"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; static constexpr llvm::StringRef kPrintComma = "printComma"; @@ -78,6 +79,13 @@ LLVM::LLVMVoidType::get(moduleOp->getContext())); } +LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStrFn(ModuleOp moduleOp) { + return lookupOrCreateFn( + moduleOp, kPrintStr, + LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)), + LLVM::LLVMVoidType::get(moduleOp->getContext())); +} + LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) { return lookupOrCreateFn(moduleOp, kPrintOpen, {}, LLVM::LLVMVoidType::get(moduleOp->getContext())); diff --git a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir 
b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -test-cf-assert \ +// RUN: -convert-func-to-llvm | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext | \ +// RUN: FileCheck %s + +func.func @main() { + %a = arith.constant 0 : i1 + %b = arith.constant 1 : i1 + // CHECK: assertion foo + cf.assert %a, "assertion foo" + // CHECK-NOT: assertion bar + cf.assert %b, "assertion bar" + return +} diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(Affine) add_subdirectory(Arith) +add_subdirectory(ControlFlow) add_subdirectory(DLTI) add_subdirectory(Func) add_subdirectory(GPU) diff --git a/mlir/test/lib/Dialect/ControlFlow/CMakeLists.txt b/mlir/test/lib/Dialect/ControlFlow/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/ControlFlow/CMakeLists.txt @@ -0,0 +1,14 @@ +# Exclude tests from libMLIR.so +add_mlir_library(MLIRControlFlowTestPasses + TestAssert.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRControlFlowToLLVM + MLIRFuncDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRPass + MLIRTransforms +) diff --git a/mlir/test/lib/Dialect/ControlFlow/TestAssert.cpp b/mlir/test/lib/Dialect/ControlFlow/TestAssert.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/ControlFlow/TestAssert.cpp @@ -0,0 +1,59 @@ +//===- TestAssert.cpp - Test cf.assert Lowering ----------------*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass for integration testing of the `cf.assert` +// to LLVM lowering. The pattern is populated with `abortOnFailure=false`, so +// a failed assertion does not abort the program. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { +struct TestAssertPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAssertPass) + + TestAssertPass() = default; + TestAssertPass(const TestAssertPass &pass) : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + StringRef getArgument() const final { return "test-cf-assert"; } + StringRef getDescription() const final { + return "Function pass to test cf.assert lowering to LLVM without abort"; + } + + void runOnOperation() override { + LLVMConversionTarget target(getContext()); + RewritePatternSet patterns(&getContext()); + + LLVMTypeConverter converter(&getContext()); + mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns, + /*abortOnFailure=*/false); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +namespace mlir::test { +void registerTestCfAssertPass() { PassRegistration(); } +} // namespace mlir::test diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- 
a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -68,6 +68,7 @@ void registerTestAliasAnalysisPass(); void registerTestBuiltinAttributeInterfaces(); void registerTestCallGraphPass(); +void registerTestCfAssertPass(); void registerTestConstantFold(); void registerTestControlFlowSink(); void registerTestGpuSerializeToCubinPass(); @@ -168,6 +169,7 @@ mlir::test::registerTestArithEmulateWideIntPass(); mlir::test::registerTestBuiltinAttributeInterfaces(); mlir::test::registerTestCallGraphPass(); + mlir::test::registerTestCfAssertPass(); mlir::test::registerTestConstantFold(); mlir::test::registerTestControlFlowSink(); mlir::test::registerTestDiagnosticsPass(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6941,6 +6941,7 @@ "//mlir/test:TestAffine", "//mlir/test:TestAnalysis", "//mlir/test:TestArith", + "//mlir/test:TestControlFlow", "//mlir/test:TestDLTI", "//mlir/test:TestDialect", "//mlir/test:TestFunc", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -676,6 +676,21 @@ ], ) +cc_library( + name = "TestControlFlow", + srcs = glob(["lib/Dialect/ControlFlow/*.cpp"]), + includes = ["lib/Dialect/Test"], + deps = [ + "//mlir:ControlFlowDialect", + "//mlir:ControlFlowToLLVM", + "//mlir:FuncDialect", + "//mlir:LLVMCommonConversion", + "//mlir:LLVMDialect", + "//mlir:Pass", + "//mlir:Transforms", + ], +) + cc_library( name = "TestShapeDialect", srcs = [