diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td --- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td @@ -41,21 +41,40 @@ let summary = "Assert operation with message attribute"; let description = [{ Assert operation with single boolean operand and an error message attribute. - If the argument is `true` this operation has no effect. Otherwise, the + If the `condition` is `true` this operation has no effect. Otherwise, the program execution will abort. The provided error message may be used by a runtime to propagate the error to the user. + + Additional error message args can be provided to produce more descriptive + error messages with runtime information. Message args are referenced with + "{}" placeholders inside the error message string. The number of message + args must match the number of placeholders. Example: ```mlir - assert %b, "Expected ... to be true" + assert %condition, "Expected ... to be true" + ``` + + ```mlir + assert %condition, "Expected {}, found {}"(%0, %1) : index, index ``` }]; - let arguments = (ins I1:$arg, StrAttr:$msg); + let arguments = (ins I1:$condition, StrAttr:$msg, Variadic:$msgArgs); + + let builders = [ + OpBuilder<(ins "Value":$condition)>, + OpBuilder<(ins "Value":$condition, "StringAttr":$msg)>, + OpBuilder<(ins "Value":$condition, "StringRef":$msg)>, + ]; + + let assemblyFormat = [{ + $condition `,` $msg (`(` $msgArgs^ `)`)? attr-dict (`:` type($msgArgs)^)? + }]; - let assemblyFormat = "$arg `,` $msg attr-dict"; let hasCanonicalizeMethod = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h --- a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h +++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h @@ -13,6 +13,19 @@ #include "llvm/ADT/SmallVector.h" +namespace mlir { +/// The condition of a runtime check can be assumed (asserted to be) satisfied +/// at runtime. The message and placeholders can be used to emit descriptive +/// error messages. +struct RuntimeVerificationCheck { + RuntimeVerificationCheck(Value condition, StringAttr msg = {}, + ArrayRef msgPlaceholders = {}); + + Value condition; + StringAttr msg; + SmallVector msgPlaceholders; +}; +} // namespace mlir /// Include the generated interface declarations. #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc" diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td --- a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td +++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td @@ -26,12 +26,19 @@ let methods = [ InterfaceMethod< /*desc=*/[{ - Generate IR to verify this op at runtime. Return a list of i1-typed - values that are assumed (can be asserted to be) satisfied (`true`) at - runtime, along with corresponding error messages in case they are not - satisfied. + Generate IR to verify this op at runtime. Return a list of + `RuntimeVerificationCheck` objects. Each object consists of: + + 1. A condition value that is assumed (can be asserted to be) satisfied + (`true`) at runtime. + 2. An error message string in case the condition is not satisfied. The + error message may contain 'n' placeholders, each indicated by '{}'. + 3. 'n' placeholder values. + + Note: Placeholders can be used to generate descriptive error messages + based on runtime values. }], - /*retTy=*/"::llvm::SmallVector<::std::pair<::mlir::Value, ::mlir::StringAttr>>", + /*retTy=*/"::llvm::SmallVector<::mlir::RuntimeVerificationCheck>", /*methodName=*/"generateRuntimeVerification", /*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::Location":$loc) diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -23,7 +23,9 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" + #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" #include namespace mlir { @@ -35,25 +37,24 @@ #define PASS_NAME "convert-cf-to-llvm" -static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) { - std::string prefix = "assert_msg_"; +static std::string generateGlobalStrSymbolName(ModuleOp moduleOp) { + std::string prefix = "assert_str_"; int counter = 0; while (moduleOp.lookupSymbol(prefix + std::to_string(counter))) ++counter; return prefix + std::to_string(counter); } -/// Generate IR that prints the given string to stderr. -static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp, - StringRef msg) { +/// Generate IR that prints the given string. +static void createPrintStr(OpBuilder &builder, Location loc, ModuleOp moduleOp, + StringRef str) { auto ip = builder.saveInsertionPoint(); builder.setInsertionPointToStart(moduleOp.getBody()); MLIRContext *ctx = builder.getContext(); // Create a zero-terminated byte representation and allocate global symbol. SmallVector elementVals; - elementVals.append(msg.begin(), msg.end()); - elementVals.push_back('\n'); + elementVals.append(str.begin(), str.end()); elementVals.push_back(0); auto dataAttrType = RankedTensorType::get( {static_cast(elementVals.size())}, builder.getI8Type()); @@ -61,23 +62,67 @@ DenseElementsAttr::get(dataAttrType, llvm::makeArrayRef(elementVals)); auto arrayTy = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); - std::string symbolName = generateGlobalMsgSymbolName(moduleOp); + std::string symbolName = generateGlobalStrSymbolName(moduleOp); auto globalOp = builder.create( loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName, dataAttr); // Emit call to `printStr` in runtime library. builder.restoreInsertionPoint(ip); - auto msgAddr = builder.create( + auto strAddr = builder.create( loc, LLVM::LLVMPointerType::get(arrayTy), globalOp.getName()); SmallVector indices(1, 0); Value gep = builder.create( - loc, LLVM::LLVMPointerType::get(builder.getI8Type()), msgAddr, indices); + loc, LLVM::LLVMPointerType::get(builder.getI8Type()), strAddr, indices); Operation *printer = LLVM::lookupOrCreatePrintStrFn(moduleOp); builder.create(loc, TypeRange(), SymbolRefAttr::get(printer), gep); } +/// Generate IR that prints the given value. +static void createPrintValue(OpBuilder &builder, + LLVMTypeConverter *typeConverter, Location loc, + ModuleOp moduleOp, Value value) { + Operation *printer = + llvm::TypeSwitch(value.getType()) + .Case([&](IntegerType intTy) { + if (intTy.isInteger(64) && !intTy.isUnsigned()) + return LLVM::lookupOrCreatePrintI64Fn(moduleOp).getOperation(); + return static_cast(nullptr); + }) + .Default([](Type t) { return nullptr; }); + + // Add support for additional types as needed. + assert(printer && "unsupported type"); + builder.create(loc, TypeRange(), SymbolRefAttr::get(printer), + value); +} + +/// Generate IR that prints the given message and interleaved message args. +static void createPrintMsg(OpBuilder &builder, LLVMTypeConverter *typeConverter, + Location loc, ModuleOp moduleOp, StringRef msg, + ValueRange msgArgs) { + StringRef remainingStr = msg; + ValueRange remainingArgs = msgArgs; + while (!remainingStr.empty()) { + bool hasPlaceholder = remainingStr.contains("{}"); + std::pair sep = remainingStr.split("{}"); + createPrintStr(builder, loc, moduleOp, sep.first); + if (hasPlaceholder) { + // A separator was found. + createPrintValue(builder, typeConverter, loc, moduleOp, + remainingArgs.front()); + remainingArgs = remainingArgs.drop_front(); + } + remainingStr = sep.second; + } + + // Print new line. + Operation *newLinePrinter = LLVM::lookupOrCreatePrintNewlineFn(moduleOp); + builder.create( + loc, TypeRange(), SymbolRefAttr::get(newLinePrinter), ValueRange()); +} + namespace { /// Lower `cf.assert`. The default lowering calls the `abort` function if the /// assertion is violated and has no effect otherwise. The failure message is @@ -102,7 +147,8 @@ // Failed block: Generate IR to print the message and call `abort`. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); - createPrintMsg(rewriter, loc, module, op.getMsg()); + createPrintMsg(rewriter, getTypeConverter(), loc, module, op.getMsg(), + adaptor.getMsgArgs()); if (abortOnFailedAssert) { // Insert the `abort` declaration if necessary. auto abortFunc = module.lookupSymbol("abort"); @@ -122,7 +168,7 @@ // Generate assertion test. rewriter.setInsertionPointToEnd(opBlock); rewriter.replaceOpWithNewOp( - op, adaptor.getArg(), continuationBlock, failureBlock); + op, adaptor.getCondition(), continuationBlock, failureBlock); return success(); } diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -710,7 +710,7 @@ Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); rewriter.setInsertionPointToEnd(cont->getPrevNode()); - rewriter.create(loc, adaptor.getArg(), + rewriter.create(loc, adaptor.getCondition(), /*trueDest=*/cont, /*trueArgs=*/ArrayRef(), /*falseDest=*/setupSetErrorBlock(coro), diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -77,13 +77,36 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { // Erase assertion if argument is constant true. - if (matchPattern(op.getArg(), m_One())) { + if (matchPattern(op.getCondition(), m_One())) { rewriter.eraseOp(op); return success(); } return failure(); } +void AssertOp::build(OpBuilder &builder, OperationState &result, + Value condition) { + build(builder, result, condition, /*msg=*/StringAttr()); +} + +void AssertOp::build(OpBuilder &builder, OperationState &result, + Value condition, StringAttr msg) { + build(builder, result, condition, msg, /*msgArgs=*/ValueRange()); +} + +void AssertOp::build(OpBuilder &builder, OperationState &result, + Value condition, StringRef msg) { + build(builder, result, condition, msg, /*msgArgs=*/ValueRange()); +} + +LogicalResult AssertOp::verify() { + int64_t expectedNumMsgArgs = getMsg().count("{}"); + if (expectedNumMsgArgs != getMsgArgs().size()) + return emitOpError() << "expected " << expectedNumMsgArgs + << " msgArgs but found " << getMsgArgs().size(); + return success(); +} + //===----------------------------------------------------------------------===// // BranchOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/GenerateRuntimeVerification.cpp --- a/mlir/lib/Dialect/ControlFlow/Transforms/GenerateRuntimeVerification.cpp +++ b/mlir/lib/Dialect/ControlFlow/Transforms/GenerateRuntimeVerification.cpp @@ -34,10 +34,11 @@ getOperation()->walk([&](Operation *op) { if (auto verifiableOp = dyn_cast(op)) { builder.setInsertionPoint(op); - SmallVector> checks = + SmallVector checks = verifiableOp.generateRuntimeVerification(builder, op->getLoc()); for (const auto &check : checks) - builder.create(op->getLoc(), check.first, check.second); + builder.create(op->getLoc(), check.condition, check.msg, + check.msgPlaceholders); } }); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -12,23 +12,121 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" +using namespace mlir; +using namespace mlir::memref; + +/// Generate a string representation of the given MemRef type with placeholders. +static std::pair> +generateMemRefTypePrinter(BaseMemRefType type, + ExtractStridedMetadataOp metadataOp) { + SmallVector placeholders; + std::string buffer; + llvm::raw_string_ostream stream(buffer); + + // Unranked memref: directly print the type. + if (auto unrankedMemrefType = type.dyn_cast()) { + unrankedMemrefType.print(stream); + return std::make_pair(stream.str(), placeholders); + } + + // Print shape of the memref. + auto rankedMemrefType = type.dyn_cast(); + stream << "memref<"; + for (int64_t i = 0; i < rankedMemrefType.getRank(); ++i) { + if (rankedMemrefType.isDynamicDim(i)) { + stream << "{}"; + placeholders.push_back(metadataOp.getResult(2 + i)); + } else { + stream << std::to_string(rankedMemrefType.getDimSize(i)); + } + stream << "x"; + } + + // Print element type. + rankedMemrefType.getElementType().print(stream); + + // Get result offset and strides. + int64_t resultOffset; + SmallVector resultStrides; + if (failed( + getStridesAndOffset(rankedMemrefType, resultStrides, resultOffset))) + llvm_unreachable("could not get strides and offset"); + + // Print strides and offset. + stream << ", strided<["; + for (int64_t i = 0; i < rankedMemrefType.getRank(); ++i) { + if (resultStrides[i] == ShapedType::kDynamic) { + stream << "{}"; + placeholders.push_back( + metadataOp.getResult(2 + rankedMemrefType.getRank() + i)); + } else { + stream << std::to_string(resultStrides[i]); + } + + if (i < rankedMemrefType.getRank() - 1) + stream << ", "; + } + stream << "], offset: "; + if (resultOffset == ShapedType::kDynamic) { + stream << "{}"; + placeholders.push_back(metadataOp.getResult(1)); + } else { + stream << std::to_string(resultOffset); + } + stream << ">"; + + return std::make_pair(stream.str(), placeholders); +} + +/// Generate an "invalid" memref.cast error message with placeholders for +/// dynamic dims/strides/offset. +static std::pair> +generateInvalidCastMessage(OpBuilder &builder, MemRefType srcType, + MemRefType targetType, + ExtractStridedMetadataOp metadataOp) { + SmallVector placeholders; + std::string srcTypeStr; + std::tie(srcTypeStr, placeholders) = + generateMemRefTypePrinter(srcType, metadataOp); + + std::string buffer; + llvm::raw_string_ostream stream(buffer); + stream << "memref.cast: invalid cast from " << srcTypeStr << " to "; + targetType.print(stream); + return std::make_pair(stream.str(), placeholders); +} + +/// Generate a RuntimeVerificationCheck with the given condition for an invalid +/// memref.cast op. +static RuntimeVerificationCheck +generateInvalidCastCheck(OpBuilder &builder, Value condition, + MemRefType srcType, MemRefType targetType, + ExtractStridedMetadataOp metadataOp) { + std::string msg; + SmallVector placeholders; + std::tie(msg, placeholders) = + generateInvalidCastMessage(builder, srcType, targetType, metadataOp); + return RuntimeVerificationCheck(condition, builder.getStringAttr(msg), + placeholders); +} + namespace mlir { namespace memref { namespace { struct CastOpInterface : public RuntimeVerifiableOpInterface::ExternalModel { - SmallVector> + SmallVector generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto castOp = cast(op); auto srcType = castOp.getSource().getType().cast(); - SmallVector> result; + SmallVector checks; // Nothing to check if the result is an unranked memref. auto resultType = castOp.getType().dyn_cast(); if (!resultType) - return result; + return checks; if (srcType.isa()) { // Check rank. @@ -37,10 +135,29 @@ builder.create(loc, resultType.getRank()); Value isSameRank = builder.create( loc, arith::CmpIPredicate::eq, srcRank, resultRank); - result.emplace_back( - isSameRank, builder.getStringAttr("memref::CastOp: rank mismatch")); + checks.emplace_back( + isSameRank, + builder.getStringAttr("memref.cast: invalid cast from rank {} to " + + std::to_string(resultType.getRank())), + ArrayRef{srcRank}); } + // Get source offset and strides. We do not have an op to get extract + // offsets and strides from unranked memrefs, so cast the source to a type + // with fully dynamic layout, from which we can then extract the offset and + // strides. (Rank was already verified.) + int64_t dynamicOffset = ShapedType::kDynamic; + SmallVector dynamicShape(resultType.getRank(), + ShapedType::kDynamic); + auto stridedLayout = StridedLayoutAttr::get(builder.getContext(), + dynamicOffset, dynamicShape); + auto rankedSrcType = + MemRefType::get(dynamicShape, resultType.getElementType(), + stridedLayout, resultType.getMemorySpace()); + Value helperCast = + builder.create(loc, rankedSrcType, castOp.getSource()); + auto metadataOp = builder.create(loc, helperCast); + // Check dimension sizes. for (const auto &it : llvm::enumerate(resultType.getShape())) { // Static dim size -> static/dynamic dim size does not need verification. @@ -58,33 +175,15 @@ builder.create(loc, it.value()); Value isSameSz = builder.create( loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); - result.emplace_back( - isSameSz, - builder.getStringAttr("memref::CastOp: size mismatch of dim " + - std::to_string(it.index()))); + checks.push_back(generateInvalidCastCheck( + builder, isSameSz, rankedSrcType, resultType, metadataOp)); } - // Get source offset and strides. We do not have an op to get extract - // offsets and strides from unranked memrefs, so cast the source to a type - // with fully dynamic layout, from which we can then extract the offset and - // strides. (Rank was already verified.) - int64_t dynamicOffset = ShapedType::kDynamic; - SmallVector dynamicShape(resultType.getRank(), - ShapedType::kDynamic); - auto stridedLayout = StridedLayoutAttr::get(builder.getContext(), - dynamicOffset, dynamicShape); - auto dynStridesType = - MemRefType::get(dynamicShape, resultType.getElementType(), - stridedLayout, resultType.getMemorySpace()); - Value helperCast = - builder.create(loc, dynStridesType, castOp.getSource()); - auto metadataOp = builder.create(loc, helperCast); - // Get result offset and strides. int64_t resultOffset; SmallVector resultStrides; if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) - return result; + llvm_unreachable("could not get strides and offset"); // Check offset. if (resultOffset != ShapedType::kDynamic) { @@ -94,8 +193,8 @@ builder.create(loc, resultOffset); Value isSameOffset = builder.create( loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); - result.emplace_back(isSameOffset, builder.getStringAttr( - "memref::CastOp: offset mismatch")); + checks.push_back(generateInvalidCastCheck( + builder, isSameOffset, rankedSrcType, resultType, metadataOp)); } // Check strides. @@ -110,24 +209,22 @@ builder.create(loc, it.value()); Value isSameStride = builder.create( loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); - result.emplace_back( - isSameStride, - builder.getStringAttr("memref::CastOp: stride mismatch of dim " + - std::to_string(it.index()))); + checks.push_back(generateInvalidCastCheck( + builder, isSameStride, rankedSrcType, resultType, metadataOp)); } - return result; + return checks; } }; struct ExpandShapeOpInterface : public RuntimeVerifiableOpInterface::ExternalModel { - SmallVector> + SmallVector generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto expandShapeOp = cast(op); - SmallVector> result; + SmallVector checks; // Verify that the expanded dim sizes are a product of the collapsed dim // size. @@ -154,13 +251,13 @@ Value isModZero = builder.create( loc, arith::CmpIPredicate::eq, mod, builder.create(loc, 0)); - result.emplace_back( + checks.emplace_back( isModZero, builder.getStringAttr("memref::ExpandShapeOp: static result dims in " "reassoc group do not divide src dim evenly")); } - return result; + return checks; } }; } // namespace diff --git a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp --- a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp +++ b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp @@ -48,7 +48,7 @@ // For debug assertions only: Print the given string. extern "C" void printStr(int8_t *str) { - fputs(reinterpret_cast(str), stderr); + fputs(reinterpret_cast(str), stdout); } extern "C" void memrefCopy(int64_t elemSize, UnrankedMemRefType *srcArg, diff --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp --- a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp +++ b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp @@ -13,5 +13,14 @@ class OpBuilder; } // namespace mlir +using namespace mlir; + +RuntimeVerificationCheck::RuntimeVerificationCheck( + Value condition, StringAttr msg, ArrayRef msgPlaceholders) + : condition(condition), msg(msg), msgPlaceholders(msgPlaceholders) { + assert(!msg || msg.strref().count("{}") == msgPlaceholders.size() && + "invalid number of placeholders"); +} + /// Include the definitions of the interface. #include "mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc" diff --git a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir --- a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir +++ b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s -convert-cf-to-llvm="abort-on-failed-assert=0" \ -// RUN: -convert-func-to-llvm | \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext 2>&1 | \ // RUN: FileCheck %s @@ -7,9 +7,12 @@ func.func @main() { %a = arith.constant 0 : i1 %b = arith.constant 1 : i1 + %cst = arith.constant 123 : index // CHECK: assertion foo cf.assert %a, "assertion foo" // CHECK-NOT: assertion bar cf.assert %b, "assertion bar" + // CHECK: assertion 123 foo + cf.assert %a, "assertion {} foo"(%cst) : index return } diff --git a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir --- a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir @@ -26,16 +26,18 @@ // All casts inside the called functions are invalid at runtime. %alloc = memref.alloc() : memref<5xf32> - // CHECK: memref::CastOp: size mismatch of dim 0 + // CHECK: memref.cast: invalid cast from memref<5xf32, strided<[1], offset: 0> to memref<10xf32> %1 = memref.cast %alloc : memref<5xf32> to memref func.call @cast_to_static_dim(%1) : (memref) -> (memref<10xf32>) - // CHECK-NEXT: memref::CastOp: rank mismatch + // CHECK-NEXT: memref.cast: invalid cast from rank 1 to 0 %3 = memref.cast %alloc : memref<5xf32> to memref<*xf32> func.call @cast_to_ranked(%3) : (memref<*xf32>) -> (memref) - // CHECK-NEXT: memref::CastOp: offset mismatch - // CHECK-NEXT: memref::CastOp: stride mismatch of dim 0 + // Next error is printed twice (because of abort-on-failed-assert=0). One for + // the stride and one for the offset. + // CHECK-NEXT: memref.cast: invalid cast from memref<5xf32, strided<[1], offset: 0> to memref> + // CHECK-NEXT: memref.cast: invalid cast from memref<5xf32, strided<[1], offset: 0> to memref> %4 = memref.cast %alloc : memref<5xf32> to memref> func.call @cast_to_static_strides(%4)