diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -14,9 +14,125 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" +using namespace mlir; + +/// Build the message for a failed runtime check (header, op, msg, location). +static std::string generateErrorMessage(Operation *op, const std::string &msg) { + std::string buffer; + llvm::raw_string_ostream stream(buffer); + OpPrintingFlags flags; + stream << "ERROR: Runtime op verification failed\n"; + op->print(stream, flags); + stream << "\n^ " << msg; + stream << "\nLocation: "; + op->getLoc().print(stream); + return stream.str(); +} + namespace mlir { namespace memref { namespace { +struct CastOpInterface + : public RuntimeVerifiableOpInterface::ExternalModel { + void generateRuntimeVerification(Operation *op, OpBuilder &builder, + Location loc) const { + auto castOp = cast(op); + auto srcType = castOp.getSource().getType().cast(); + + // Nothing to check if the result is an unranked memref. + auto resultType = castOp.getType().dyn_cast(); + if (!resultType) + return; + + if (srcType.isa()) { + // Unranked source: check at runtime that its rank matches the result. + Value srcRank = builder.create(loc, castOp.getSource()); + Value resultRank = + builder.create(loc, resultType.getRank()); + Value isSameRank = builder.create( + loc, arith::CmpIPredicate::eq, srcRank, resultRank); + builder.create(loc, isSameRank, + generateErrorMessage(op, "rank mismatch")); + } + + // Get source offset and strides. 
We do not have an op to get offsets and + // strides from unranked memrefs, so cast the source to a type with fully + // dynamic layout, from which we can then extract the offset and strides. + // (Rank was already verified.)
+ int64_t dynamicOffset = ShapedType::kDynamic; + SmallVector dynamicShape(resultType.getRank(), + ShapedType::kDynamic); + auto stridedLayout = StridedLayoutAttr::get(builder.getContext(), + dynamicOffset, dynamicShape); + auto dynStridesType = + MemRefType::get(dynamicShape, resultType.getElementType(), + stridedLayout, resultType.getMemorySpace()); + Value helperCast = + builder.create(loc, dynStridesType, castOp.getSource()); + auto metadataOp = builder.create(loc, helperCast); + + // Check dimension sizes that cannot be verified statically. + for (const auto &it : llvm::enumerate(resultType.getShape())) { + // Static dim size -> static/dynamic dim size does not need verification. + if (auto rankedSrcType = srcType.dyn_cast()) + if (!rankedSrcType.isDynamicDim(it.index())) + continue; + + // Static/dynamic dim size -> dynamic dim size does not need verification. + if (resultType.isDynamicDim(it.index())) + continue; + + Value srcDimSz = + builder.create(loc, castOp.getSource(), it.index()); + Value resultDimSz = + builder.create(loc, it.value()); + Value isSameSz = builder.create( + loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); + builder.create( + loc, isSameSz, + generateErrorMessage(op, "size mismatch of dim " + + std::to_string(it.index()))); + } + + // Get result offset and strides; stop if they cannot be computed. + int64_t resultOffset; + SmallVector resultStrides; + if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) + return; + + // Check offset; a dynamic result offset needs no verification. + if (resultOffset != ShapedType::kDynamic) { + // The source offset is result #1 of metadataOp. + Value srcOffset = metadataOp.getResult(1); + Value resultOffsetVal = + builder.create(loc, resultOffset); + Value isSameOffset = builder.create( + loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); + builder.create(loc, isSameOffset, + generateErrorMessage(op, "offset mismatch")); + } + + // Check strides; stride i is result #(2 + rank + i) of metadataOp.
+ for (const auto &it : llvm::enumerate(resultStrides)) { + // Static/dynamic stride -> dynamic stride does not need verification. + if (it.value() == ShapedType::kDynamic) + continue; + + Value srcStride = + metadataOp.getResult(2 + resultType.getRank() + it.index()); + Value resultStrideVal = + builder.create(loc, it.value()); + Value isSameStride = builder.create( + loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); + builder.create( + loc, isSameStride, + generateErrorMessage(op, "stride mismatch of dim " + + std::to_string(it.index()))); + } + } +}; + struct ExpandShapeOpInterface : public RuntimeVerifiableOpInterface::ExternalModel { @@ -53,7 +169,8 @@ builder.create(loc, 0)); builder.create( loc, isModZero, - "static result dims in reassoc group do not divide src dim evenly"); + generateErrorMessage(op, "static result dims in reassoc group do not " + "divide src dim evenly")); } } }; @@ -64,6 +181,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { + CastOp::attachInterface(*ctx); ExpandShapeOp::attachInterface(*ctx); // Load additional dialects of which ops may get created.
diff --git a/mlir/test/Dialect/MemRef/runtime-verification.mlir b/mlir/test/Dialect/MemRef/runtime-verification.mlir --- a/mlir/test/Dialect/MemRef/runtime-verification.mlir +++ b/mlir/test/Dialect/MemRef/runtime-verification.mlir @@ -7,7 +7,7 @@ // CHECK-DAG: %[[dim:.*]] = memref.dim %[[m]], %[[c0]] // CHECK: %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]] // CHECK: %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]] -// CHECK: cf.assert %[[cmpi]], "static result dims in reassoc group do not divide src dim evenly" +// CHECK: cf.assert %[[cmpi]], "ERROR: Runtime op verification failed func.func @expand_shape(%m: memref) -> memref { %0 = memref.expand_shape %m [[0, 1]] : memref into memref return %0 : memref diff --git a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir @@ -0,0 +1,69 @@ +// RUN: mlir-opt %s -generate-runtime-verification -convert-memref-to-llvm \ +// RUN: -test-cf-assert \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext 2>&1 | \ +// RUN: FileCheck %s + +func.func @cast_to_static_dim(%m: memref) -> memref<10xf32> { + %0 = memref.cast %m : memref to memref<10xf32> + return %0 : memref<10xf32> +} + +func.func @cast_to_ranked(%m: memref<*xf32>) -> memref { + %0 = memref.cast %m : memref<*xf32> to memref + return %0 : memref +} + +func.func @cast_to_static_strides(%m: memref>) + -> memref> { + %0 = memref.cast %m : memref> + to memref> + return %0 : memref> +} + +func.func @valid_cast(%m: memref<*xf32>) -> memref { + %0 = memref.cast %m : memref<*xf32> to memref + return %0 : memref +} + +func.func @main() { + // Each called function performs a cast that is invalid at runtime, except + // @valid_cast, which is called last.
+ %alloc = memref.alloc() : memref<5xf32> + + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: memref.cast %{{.*}} : memref to memref<10xf32> + // CHECK-NEXT: ^ size mismatch of dim 0 + // CHECK-NEXT: Location: loc({{.*}}) + %1 = memref.cast %alloc : memref<5xf32> to memref + func.call @cast_to_static_dim(%1) : (memref) -> (memref<10xf32>) + + // CHECK-NEXT: ERROR: Runtime op verification failed + // CHECK-NEXT: memref.cast %{{.*}} : memref<*xf32> to memref + // CHECK-NEXT: ^ rank mismatch + // CHECK-NEXT: Location: loc({{.*}}) + %3 = memref.cast %alloc : memref<5xf32> to memref<*xf32> + func.call @cast_to_ranked(%3) : (memref<*xf32>) -> (memref) + + // CHECK-NEXT: ERROR: Runtime op verification failed + // CHECK-NEXT: memref.cast %{{.*}} : memref> to memref> + // CHECK-NEXT: ^ offset mismatch + // CHECK-NEXT: Location: loc({{.*}}) + + // CHECK-NEXT: ERROR: Runtime op verification failed + // CHECK-NEXT: memref.cast %{{.*}} : memref> to memref> + // CHECK-NEXT: ^ stride mismatch of dim 0 + // CHECK-NEXT: Location: loc({{.*}}) + %4 = memref.cast %alloc + : memref<5xf32> to memref> + func.call @cast_to_static_strides(%4) + : (memref>) + -> (memref>) + + // The last cast succeeds, so no further error may be reported. + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @valid_cast(%3) : (memref<*xf32>) -> (memref) + + return +}