diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -15,6 +15,111 @@ namespace mlir { namespace memref { namespace { +struct CastOpInterface + : public RuntimeVerifiableOpInterface::ExternalModel { + SmallVector> + generateRuntimeVerification(Operation *op, OpBuilder &builder, + Location loc) const { + auto castOp = cast(op); + auto srcType = castOp.getSource().getType().cast(); + SmallVector> result; + + // Nothing to check if the result is an unranked memref. + auto resultType = castOp.getType().dyn_cast(); + if (!resultType) + return result; + + if (srcType.isa()) { + // Check rank. + Value srcRank = builder.create(loc, castOp.getSource()); + Value resultRank = + builder.create(loc, resultType.getRank()); + Value isSameRank = builder.create( + loc, arith::CmpIPredicate::eq, srcRank, resultRank); + result.emplace_back( + isSameRank, builder.getStringAttr("memref::CastOp: rank mismatch")); + } + + // Check dimension sizes. + for (const auto &it : llvm::enumerate(resultType.getShape())) { + // Static dim size -> static/dynamic dim size does not need verification. + if (auto rankedSrcType = srcType.dyn_cast()) + if (!rankedSrcType.isDynamicDim(it.index())) + continue; + + // Static/dynamic dim size -> dynamic dim size does not need verification. + if (resultType.isDynamicDim(it.index())) + continue; + + Value srcDimSz = + builder.create(loc, castOp.getSource(), it.index()); + Value resultDimSz = + builder.create(loc, it.value()); + Value isSameSz = builder.create( + loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); + result.emplace_back( + isSameSz, + builder.getStringAttr("memref::CastOp: size mismatch of dim " + + std::to_string(it.index()))); + } + + // Get result offset and strides; bail out early for non-strided layouts. + int64_t resultOffset; + SmallVector resultStrides; + if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) + return result; +
+ // Get source offset and strides. We do not have an op to extract + // offsets and strides from unranked memrefs, so cast the source to a type + // with fully dynamic layout, from which we can then extract the offset and + // strides. (Rank was already verified.) + int64_t dynamicOffset = ShapedType::kDynamic; + SmallVector dynamicShape(resultType.getRank(), + ShapedType::kDynamic); + auto stridedLayout = StridedLayoutAttr::get(builder.getContext(), + dynamicOffset, dynamicShape); + auto dynStridesType = + MemRefType::get(dynamicShape, resultType.getElementType(), + stridedLayout, resultType.getMemorySpace()); + Value helperCast = + builder.create(loc, dynStridesType, castOp.getSource()); + auto metadataOp = builder.create(loc, helperCast); + + // Check offset. + if (resultOffset != ShapedType::kDynamic) { + // Static/dynamic offset -> static offset needs runtime verification. + Value srcOffset = metadataOp.getResult(1); + Value resultOffsetVal = + builder.create(loc, resultOffset); + Value isSameOffset = builder.create( + loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); + result.emplace_back(isSameOffset, builder.getStringAttr( + "memref::CastOp: offset mismatch")); + } + + // Check strides. + for (const auto &it : llvm::enumerate(resultStrides)) { + // Static/dynamic stride -> dynamic stride does not need verification.
+ if (it.value() == ShapedType::kDynamic) + continue; + + Value srcStride = + metadataOp.getResult(2 + resultType.getRank() + it.index()); + Value resultStrideVal = + builder.create(loc, it.value()); + Value isSameStride = builder.create( + loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); + result.emplace_back( + isSameStride, + builder.getStringAttr("memref::CastOp: stride mismatch of dim " + + std::to_string(it.index()))); + } + + return result; + } +}; + struct ExpandShapeOpInterface : public RuntimeVerifiableOpInterface::ExternalModel { @@ -51,8 +156,8 @@ builder.create(loc, 0)); result.emplace_back( isModZero, - builder.getStringAttr("static result dims in reassoc group do not " - "divide src dim evenly")); + builder.getStringAttr("memref::ExpandShapeOp: static result dims in " + "reassoc group do not divide src dim evenly")); } return result; @@ -65,6 +170,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels( DialectRegistry &registry) { registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { + CastOp::attachInterface(*ctx); ExpandShapeOp::attachInterface(*ctx); // Load additional dialects of which ops may get created.
diff --git a/mlir/test/Dialect/MemRef/runtime-verification.mlir b/mlir/test/Dialect/MemRef/runtime-verification.mlir --- a/mlir/test/Dialect/MemRef/runtime-verification.mlir +++ b/mlir/test/Dialect/MemRef/runtime-verification.mlir @@ -7,7 +7,7 @@ // CHECK-DAG: %[[dim:.*]] = memref.dim %[[m]], %[[c0]] // CHECK: %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]] // CHECK: %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]] -// CHECK: cf.assert %[[cmpi]], "static result dims in reassoc group do not divide src dim evenly" +// CHECK: cf.assert %[[cmpi]], "memref::ExpandShapeOp: static result dims in reassoc group do not divide src dim evenly" func.func @expand_shape(%m: memref) -> memref { %0 = memref.expand_shape %m [[0, 1]] : memref into memref return %0 : memref diff --git a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir @@ -0,0 +1,46 @@ +// RUN: mlir-opt %s -generate-runtime-verification -convert-memref-to-llvm \ +// RUN: -convert-cf-to-llvm="abort-on-failed-assert=0" \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext 2>&1 | \ +// RUN: FileCheck %s + +func.func @cast_to_static_dim(%m: memref) -> memref<10xf32> { + %0 = memref.cast %m : memref to memref<10xf32> + return %0 : memref<10xf32> +} + +func.func @cast_to_ranked(%m: memref<*xf32>) -> memref { + %0 = memref.cast %m : memref<*xf32> to memref + return %0 : memref +} + +func.func @cast_to_static_strides(%m: memref>) + -> memref> { + %0 = memref.cast %m : memref> + to memref> + return %0 : memref> +} + +func.func @main() { + // All casts inside the called functions are invalid at runtime. Failed asserts do not abort (abort-on-failed-assert=0 in the RUN line), so every failure message is printed and matched in order by the CHECK/CHECK-NEXT lines below.
+ %alloc = memref.alloc() : memref<5xf32> + + // CHECK: memref::CastOp: size mismatch of dim 0 + %1 = memref.cast %alloc : memref<5xf32> to memref + func.call @cast_to_static_dim(%1) : (memref) -> (memref<10xf32>) + + // CHECK-NEXT: memref::CastOp: rank mismatch + %3 = memref.cast %alloc : memref<5xf32> to memref<*xf32> + func.call @cast_to_ranked(%3) : (memref<*xf32>) -> (memref) + + // CHECK-NEXT: memref::CastOp: offset mismatch + // CHECK-NEXT: memref::CastOp: stride mismatch of dim 0 + %4 = memref.cast %alloc + : memref<5xf32> to memref> + func.call @cast_to_static_strides(%4) + : (memref>) + -> (memref>) + + return +}