diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -41,21 +41,40 @@
   let summary = "Assert operation with message attribute";
   let description = [{
     Assert operation with single boolean operand and an error message attribute.
-    If the argument is `true` this operation has no effect. Otherwise, the
+    If the `condition` is `true` this operation has no effect. Otherwise, the
     program execution will abort. The provided error message may be used by a
     runtime to propagate the error to the user.
+
+    Additional error message args can be provided to produce more descriptive
+    error messages with runtime information. Message args are referenced with
+    "{}" placeholders inside the error message string. The number of message
+    args must match the number of placeholders (checked by the verifier).
 
     Example:
 
     ```mlir
-    assert %b, "Expected ... to be true"
+    assert %condition, "Expected ... to be true"
+    ```
+
+    ```mlir
+    assert %condition, "Expected {}, found {}"(%0, %1) : index, index
     ```
   }];
 
-  let arguments = (ins I1:$arg, StrAttr:$msg);
+  let arguments = (ins I1:$condition, StrAttr:$msg, Variadic<AnyType>:$msgArgs);
+
+  let builders = [
+    OpBuilder<(ins "Value":$condition)>,
+    OpBuilder<(ins "Value":$condition, "StringAttr":$msg)>,
+    OpBuilder<(ins "Value":$condition, "StringRef":$msg)>,
+  ];
+
+  let assemblyFormat = [{
+    $condition `,` $msg (`(` $msgArgs^ `)`)? attr-dict (`:` type($msgArgs)^)?
+  }];
 
-  let assemblyFormat = "$arg `,` $msg attr-dict";
   let hasCanonicalizeMethod = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -23,7 +23,9 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+
 #include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include <functional>
 
 namespace mlir {
@@ -35,24 +37,24 @@
 
 #define PASS_NAME "convert-cf-to-llvm"
 
-static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
-  std::string prefix = "assert_msg_";
+static std::string generateGlobalStrSymbolName(ModuleOp moduleOp) {
+  std::string prefix = "assert_str_";
   int counter = 0;
   while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
     ++counter;
   return prefix + std::to_string(counter);
 }
 
-/// Generate IR that prints the given string to stderr.
-static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
-                           StringRef msg) {
+/// Generate IR that prints the given string.
+static void createPrintStr(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+                           StringRef str) {
   auto ip = builder.saveInsertionPoint();
   builder.setInsertionPointToStart(moduleOp.getBody());
   MLIRContext *ctx = builder.getContext();
 
   // Create a zero-terminated byte representation and allocate global symbol.
   SmallVector<uint8_t> elementVals;
-  elementVals.append(msg.begin(), msg.end());
+  elementVals.append(str.begin(), str.end());
   elementVals.push_back(0);
   auto dataAttrType = RankedTensorType::get(
       {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
@@ -60,23 +62,67 @@
       DenseElementsAttr::get(dataAttrType, llvm::makeArrayRef(elementVals));
   auto arrayTy =
       LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
-  std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
+  std::string symbolName = generateGlobalStrSymbolName(moduleOp);
   auto globalOp = builder.create<LLVM::GlobalOp>(
       loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
       dataAttr);
 
   // Emit call to `printStr` in runtime library.
   builder.restoreInsertionPoint(ip);
-  auto msgAddr = builder.create<LLVM::AddressOfOp>(
+  auto strAddr = builder.create<LLVM::AddressOfOp>(
       loc, LLVM::LLVMPointerType::get(arrayTy), globalOp.getName());
   SmallVector<LLVM::GEPArg> indices(1, 0);
   Value gep = builder.create<LLVM::GEPOp>(
-      loc, LLVM::LLVMPointerType::get(builder.getI8Type()), msgAddr, indices);
+      loc, LLVM::LLVMPointerType::get(builder.getI8Type()), strAddr, indices);
   Operation *printer = LLVM::lookupOrCreatePrintStrFn(moduleOp);
   builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
                                gep);
 }
 
+/// Generate IR that prints the given value with the matching runtime print
+/// function: 64-bit (signless or signed) integers via `printI64`, f32 via
+/// `printF32` and f64 via `printF64`. Any other type is unsupported and
+/// triggers an assertion.
+static void createPrintValue(OpBuilder &builder,
+                             LLVMTypeConverter *typeConverter, Location loc,
+                             ModuleOp moduleOp, Value value) {
+  Operation *printer =
+      llvm::TypeSwitch<Type, Operation *>(value.getType())
+          .Case<IntegerType>([&](IntegerType intTy) -> Operation * {
+            if (intTy.isInteger(64) && !intTy.isUnsigned())
+              return LLVM::lookupOrCreatePrintI64Fn(moduleOp).getOperation();
+            return nullptr;
+          })
+          .Case<Float32Type>([&](Float32Type) -> Operation * {
+            return LLVM::lookupOrCreatePrintF32Fn(moduleOp).getOperation();
+          })
+          .Case<Float64Type>([&](Float64Type) -> Operation * {
+            return LLVM::lookupOrCreatePrintF64Fn(moduleOp).getOperation();
+          })
+          .Default([](Type) -> Operation * { return nullptr; });
+
+  // Add support for additional types as needed.
+  assert(printer && "unsupported type");
+  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
+                               value);
+}
+
+/// Generate IR that prints the given message and interleaved message args.
+static void createPrintMsg(OpBuilder &builder, LLVMTypeConverter *typeConverter,
+                           Location loc, ModuleOp moduleOp, StringRef msg,
+                           ValueRange msgArgs) {
+  StringRef remainingStr = msg;
+  ValueRange remainingArgs = msgArgs;
+  while (!remainingStr.empty()) {
+    // Print the text up to the next placeholder (or the rest of the message).
+    // Empty segments (e.g., leading or adjacent placeholders) are skipped so
+    // that no zero-length string globals and print calls are emitted.
+    size_t placeholderPos = remainingStr.find("{}");
+    StringRef segment = remainingStr.take_front(placeholderPos);
+    if (!segment.empty())
+      createPrintStr(builder, loc, moduleOp, segment);
+    if (placeholderPos == StringRef::npos)
+      break;
+    // The verifier guarantees exactly one message arg per placeholder.
+    createPrintValue(builder, typeConverter, loc, moduleOp,
+                     remainingArgs.front());
+    remainingArgs = remainingArgs.drop_front();
+    remainingStr = remainingStr.drop_front(placeholderPos + 2);
+  }
+
+  // Print new line.
+  Operation *newLinePrinter = LLVM::lookupOrCreatePrintNewlineFn(moduleOp);
+  builder.create<LLVM::CallOp>(
+      loc, TypeRange(), SymbolRefAttr::get(newLinePrinter), ValueRange());
+}
+
 namespace {
 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
 /// assertion is violated and has no effect otherwise. The failure message is
@@ -101,7 +147,8 @@
 
     // Failed block: Generate IR to print the message and call `abort`.
     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
-    createPrintMsg(rewriter, loc, module, op.getMsg());
+    createPrintMsg(rewriter, getTypeConverter(), loc, module, op.getMsg(),
+                   adaptor.getMsgArgs());
     if (abortOnFailedAssert) {
       // Insert the `abort` declaration if necessary.
       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
@@ -121,7 +168,7 @@
     // Generate assertion test.
     rewriter.setInsertionPointToEnd(opBlock);
     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
-        op, adaptor.getArg(), continuationBlock, failureBlock);
+        op, adaptor.getCondition(), continuationBlock, failureBlock);
 
     return success();
   }
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -730,7 +730,7 @@
 
     Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
     rewriter.setInsertionPointToEnd(cont->getPrevNode());
-    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
+    rewriter.create<cf::CondBranchOp>(loc, adaptor.getCondition(),
                                       /*trueDest=*/cont,
                                       /*trueArgs=*/ArrayRef<Value>(),
                                       /*falseDest=*/setupSetErrorBlock(coro),
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -77,13 +77,39 @@
 
 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
   // Erase assertion if argument is constant true.
-  if (matchPattern(op.getArg(), m_One())) {
+  if (matchPattern(op.getCondition(), m_One())) {
     rewriter.eraseOp(op);
     return success();
   }
   return failure();
 }
 
+void AssertOp::build(OpBuilder &builder, OperationState &result,
+                     Value condition) {
+  // `msg` is a required attribute, so default to an empty message instead of
+  // passing a null StringAttr.
+  build(builder, result, condition, /*msg=*/builder.getStringAttr(""));
+}
+
+void AssertOp::build(OpBuilder &builder, OperationState &result,
+                     Value condition, StringAttr msg) {
+  build(builder, result, condition, msg, /*msgArgs=*/ValueRange());
+}
+
+void AssertOp::build(OpBuilder &builder, OperationState &result,
+                     Value condition, StringRef msg) {
+  build(builder, result, condition, msg, /*msgArgs=*/ValueRange());
+}
+
+LogicalResult AssertOp::verify() {
+  // Every "{}" placeholder in `msg` consumes exactly one message arg.
+  size_t expectedNumMsgArgs = getMsg().count("{}");
+  if (expectedNumMsgArgs != getMsgArgs().size())
+    return emitOpError() << "expected " << expectedNumMsgArgs
+                         << " msgArgs but found " << getMsgArgs().size();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -28,7 +28,7 @@
 static constexpr llvm::StringRef kPrintU64 = "printU64";
 static constexpr llvm::StringRef kPrintF32 = "printF32";
 static constexpr llvm::StringRef kPrintF64 = "printF64";
-static constexpr llvm::StringRef kPrintStr = "puts";
+static constexpr llvm::StringRef kPrintStr = "printStr";
 static constexpr llvm::StringRef kPrintOpen = "printOpen";
 static constexpr llvm::StringRef kPrintClose = "printClose";
 static constexpr llvm::StringRef kPrintComma = "printComma";
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -13,6 +13,71 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
 
+using namespace mlir;
+using namespace mlir::memref;
+
+/// Generate a string representation of the given MemRef type with placeholders.
+/// Placeholders are represented by {} in the result string. The returned values
+/// fill the placeholders in order: dim sizes, strides, offset (all taken from
+/// `metadataOp`). The memory space is not included in the string.
+static std::pair<std::string, SmallVector<Value>>
+generateMemRefTypePrinter(BaseMemRefType type,
+                          ExtractStridedMetadataOp metadataOp) {
+  SmallVector<Value> placeholders;
+  std::string buffer;
+  llvm::raw_string_ostream stream(buffer);
+
+  // Unranked memref: directly print the type.
+  if (auto unrankedMemrefType = type.dyn_cast<UnrankedMemRefType>()) {
+    unrankedMemrefType.print(stream);
+    return std::make_pair(stream.str(), placeholders);
+  }
+
+  // Print shape of the memref. `metadataOp` results are laid out as: base
+  // buffer, offset, sizes (`rank` many), strides (`rank` many).
+  auto rankedMemrefType = type.cast<MemRefType>();
+  int64_t rank = rankedMemrefType.getRank();
+  stream << "memref<";
+  for (int64_t i = 0; i < rank; ++i) {
+    stream << "{}x";
+    placeholders.push_back(metadataOp.getResult(2 + i));
+  }
+
+  // Print element type.
+  rankedMemrefType.getElementType().print(stream);
+
+  // Print strides and offset, then close both `strided<` and `memref<`.
+  stream << ", strided<[";
+  for (int64_t i = 0; i < rank; ++i) {
+    stream << "{}";
+    placeholders.push_back(metadataOp.getResult(2 + rank + i));
+    if (i < rank - 1)
+      stream << ", ";
+  }
+  stream << "], offset: {}>>";
+  placeholders.push_back(metadataOp.getResult(1));
+
+  return std::make_pair(stream.str(), placeholders);
+}
+
+/// Generate a runtime assertion with the given condition value and an
+/// "invalid cast" error message indicating the type mismatch.
+static void generateInvalidCastCheck(OpBuilder &builder, Location loc,
+                                     Value condition, MemRefType srcType,
+                                     MemRefType targetType,
+                                     ExtractStridedMetadataOp metadataOp) {
+  SmallVector<Value> placeholders;
+  std::string srcTypeStr;
+  std::tie(srcTypeStr, placeholders) =
+      generateMemRefTypePrinter(srcType, metadataOp);
+
+  std::string buffer;
+  llvm::raw_string_ostream stream(buffer);
+  stream << "memref.cast: invalid cast from " << srcTypeStr << " to ";
+  targetType.print(stream);
+  builder.create<cf::AssertOp>(loc, condition, stream.str(), placeholders);
+}
+
 namespace mlir {
 namespace memref {
 namespace {
@@ -36,10 +101,29 @@
           builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
       Value isSameRank = builder.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::eq, srcRank, resultRank);
-      builder.create<cf::AssertOp>(loc, isSameRank,
-                                   "memref::CastOp: rank mismatch");
+      builder.create<cf::AssertOp>(
+          loc, isSameRank,
+          "memref.cast: invalid cast from rank {} to " +
+              std::to_string(resultType.getRank()),
+          ArrayRef<Value>{srcRank});
     }
 
+    // Get source offset and strides. We do not have an op to get offsets and
+    // strides from unranked memrefs, so cast the source to a type with fully
+    // dynamic layout, from which we can then extract the offset and strides.
+    // (Rank was already verified.)
+    int64_t dynamicOffset = ShapedType::kDynamic;
+    SmallVector<int64_t> dynamicShape(resultType.getRank(),
+                                      ShapedType::kDynamic);
+    auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
+                                                dynamicOffset, dynamicShape);
+    auto dynStridesType =
+        MemRefType::get(dynamicShape, resultType.getElementType(),
+                        stridedLayout, resultType.getMemorySpace());
+    Value helperCast =
+        builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
+    auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);
+
     // Check dimension sizes.
     for (const auto &it : llvm::enumerate(resultType.getShape())) {
       // Static dim size -> static/dynamic dim size does not need verification.
@@ -57,32 +141,17 @@
           builder.create<arith::ConstantIndexOp>(loc, it.value());
       Value isSameSz = builder.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
-      builder.create<cf::AssertOp>(loc, isSameSz,
-                                   "memref::CastOp: size mismatch of dim " +
-                                       std::to_string(it.index()));
+      generateInvalidCastCheck(builder, loc, isSameSz, dynStridesType,
+                               resultType, metadataOp);
     }
 
-    // Get source offset and strides. We do not have an op to get extract
-    // offsets and strides from unranked memrefs, so cast the source to a type
-    // with fully dynamic layout, from which we can then extract the offset and
-    // strides. (Rank was already verified.)
-    int64_t dynamicOffset = ShapedType::kDynamic;
-    SmallVector<int64_t> dynamicShape(resultType.getRank(),
-                                      ShapedType::kDynamic);
-    auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
-                                                dynamicOffset, dynamicShape);
-    auto dynStridesType =
-        MemRefType::get(dynamicShape, resultType.getElementType(),
-                        stridedLayout, resultType.getMemorySpace());
-    Value helperCast =
-        builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
-    auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);
-
     // Get result offset and strides.
     int64_t resultOffset;
     SmallVector<int64_t> resultStrides;
+    // Non-strided result layouts are legal (e.g., when casting from an unranked
+    // memref), so skip the offset/stride checks instead of asserting.
     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
       return;
 
     // Check offset.
     if (resultOffset != ShapedType::kDynamic) {
@@ -92,8 +161,8 @@
           builder.create<arith::ConstantIndexOp>(loc, resultOffset);
       Value isSameOffset = builder.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
-      builder.create<cf::AssertOp>(loc, isSameOffset,
-                                   "memref::CastOp: offset mismatch");
+      generateInvalidCastCheck(builder, loc, isSameOffset, dynStridesType,
+                               resultType, metadataOp);
     }
 
     // Check strides.
@@ -108,9 +177,8 @@
           builder.create<arith::ConstantIndexOp>(loc, it.value());
       Value isSameStride = builder.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
-      builder.create<cf::AssertOp>(loc, isSameStride,
-                                   "memref::CastOp: stride mismatch of dim " +
-                                       std::to_string(it.index()));
+      generateInvalidCastCheck(builder, loc, isSameStride, dynStridesType,
+                               resultType, metadataOp);
     }
   }
 };
diff --git a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
--- a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
+++ b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
@@ -56,6 +56,12 @@
 extern "C" void printComma() { fputs(", ", stdout); }
 extern "C" void printNewline() { fputc('\n', stdout); }
 
+// For debug assertions only: Print the given string.
+extern "C" void printStr(int8_t *str) {
+  // Note: puts adds a new line, fputs does not add a new line.
+  fputs(reinterpret_cast<char *>(str), stdout);
+}
+
 extern "C" void memrefCopy(int64_t elemSize, UnrankedMemRefType<char> *srcArg,
                            UnrankedMemRefType<char> *dstArg) {
   DynamicMemRefType<char> src(*srcArg);
diff --git a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir
--- a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir
+++ b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -test-cf-assert \
-// RUN:     -convert-func-to-llvm | \
+// RUN:     -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext | \
 // RUN: FileCheck %s
@@ -7,9 +7,14 @@
 func.func @main() {
   %a = arith.constant 0 : i1
   %b = arith.constant 1 : i1
+  %cst = arith.constant 123 : index
   // CHECK: assertion foo
   cf.assert %a, "assertion foo"
   // CHECK-NOT: assertion bar
   cf.assert %b, "assertion bar"
+  // CHECK: assertion 123 foo
+  cf.assert %a, "assertion {} foo"(%cst) : index
+  // CHECK: 123123 baz
+  cf.assert %a, "{}{} baz"(%cst, %cst) : index, index
   return
 }
diff --git a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir
--- a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir
@@ -26,16 +26,18 @@
   // All casts inside the called functions are invalid at runtime.
   %alloc = memref.alloc() : memref<5xf32>
 
-  // CHECK: memref::CastOp: size mismatch of dim 0
+  // CHECK: memref.cast: invalid cast from memref<5xf32, strided<[1], offset: 0>> to memref<10xf32>
   %1 = memref.cast %alloc : memref<5xf32> to memref<?xf32>
   func.call @cast_to_static_dim(%1) : (memref<?xf32>) -> (memref<10xf32>)
 
-  // CHECK-NEXT: memref::CastOp: rank mismatch
+  // CHECK-NEXT: memref.cast: invalid cast from rank 1 to 0
   %3 = memref.cast %alloc : memref<5xf32> to memref<*xf32>
   func.call @cast_to_ranked(%3) : (memref<*xf32>) -> (memref<f32>)
 
-  // CHECK-NEXT: memref::CastOp: offset mismatch
-  // CHECK-NEXT: memref::CastOp: stride mismatch of dim 0
+  // Next error is printed twice (because of abort-on-failed-assert=0): once
+  // for the offset check and once for the stride check.
+  // CHECK-NEXT: memref.cast: invalid cast from memref<5xf32, strided<[1], offset: 0>> to memref<?xf32, strided<[9], offset: 5>>
+  // CHECK-NEXT: memref.cast: invalid cast from memref<5xf32, strided<[1], offset: 0>> to memref<?xf32, strided<[9], offset: 5>>
   %4 = memref.cast %alloc : memref<5xf32> to memref<?xf32, strided<[?], offset: ?>>
   func.call @cast_to_static_strides(%4)