diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h b/mlir/include/mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h
@@ -0,0 +1,22 @@
+//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
+#define MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+/// Register external models of RuntimeVerifiableOpInterface for memref ops.
+void registerRuntimeVerifiableOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -45,6 +45,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -130,6 +131,7 @@
       registry);
   linalg::registerBufferizableOpInterfaceExternalModels(registry);
   linalg::registerTilingInterfaceExternalModels(registry);
+  memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   shape::registerBufferizableOpInterfaceExternalModels(registry);
   sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(ParallelCombiningOpInterface)
+add_mlir_interface(RuntimeVerifiableOpInterface)
 add_mlir_interface(ShapedOpInterfaces)
 add_mlir_interface(SideEffectInterfaces)
 add_mlir_interface(TilingInterface)
diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h
@@ -0,0 +1,17 @@
+//===- RuntimeVerifiableOpInterface.h - Op Verification ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
+#define MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
@@ -0,0 +1,40 @@
+//===- RuntimeVerifiableOpInterface.td - Op Verification ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
+#define MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
+  let description = [{
+    Implementations of this interface generate IR for runtime op verification.
+
+    Incorrect op usage can often be caught by op verifiers based on static
+    program information. However, in the absence of static program information,
+    it can remain undetected at compile time (e.g., in case of dynamic memref
+    strides instead of static memref strides). Such cases can be checked at
+    runtime. The op-specific checks are generated by this interface.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Generate IR to verify this op at runtime, aborting runtime execution if
+        verification fails.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"generateRuntimeVerification",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+                    "::mlir::Location":$loc)
+    >,
+  ];
+}
+
+#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -64,6 +64,9 @@
 /// Creates a pass to perform common sub expression elimination.
 std::unique_ptr<Pass> createCSEPass();
 
+/// Creates a pass that generates IR to verify ops at runtime.
+std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
+
 /// Creates a loop invariant code motion pass that hoists loop invariant
 /// instructions out of the loop.
 std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -77,6 +77,16 @@
   ];
 }
 
+def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
+  let summary = "Generate additional runtime op verification checks";
+  let description = [{
+    This pass generates op-specific runtime checks using the
+    `RuntimeVerifiableOpInterface`. It can be run for debugging purposes after
+    passes that are suspected to introduce faulty IR.
+  }];
+  let constructor = "mlir::createGenerateRuntimeVerificationPass()";
+}
+
 def Inliner : Pass<"inline"> {
   let summary = "Inline function calls";
   let constructor = "mlir::createInlinerPass()";
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
   ResolveShapedTypeResultDims.cpp
+  RuntimeOpVerification.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -0,0 +1,74 @@
+//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+namespace mlir {
+namespace memref {
+namespace {
+/// Generates runtime checks for memref.expand_shape: for each reassociation
+/// group, the product of its static result dims must divide the source dim.
+struct ExpandShapeOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
+                                                         ExpandShapeOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto expandShapeOp = cast<ExpandShapeOp>(op);
+
+    // Verify that the expanded dim sizes are a product of the collapsed dim
+    // size.
+    for (auto it : llvm::enumerate(expandShapeOp.getReassociationIndices())) {
+      Value srcDimSz =
+          builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
+      int64_t groupSz = 1;
+      bool foundDynamicDim = false;
+      for (int64_t resultDim : it.value()) {
+        if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
+          // Keep this assert here in case the op is extended in the future.
+          assert(!foundDynamicDim &&
+                 "more than one dynamic dim found in reassoc group");
+          // Only read by the assert; silence unused warnings in NDEBUG builds.
+          (void)foundDynamicDim;
+          foundDynamicDim = true;
+          continue;
+        }
+        groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
+      }
+      Value staticResultDimSz =
+          builder.create<arith::ConstantIndexOp>(loc, groupSz);
+      // staticResultDimSz must divide srcDimSz evenly.
+      Value mod =
+          builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
+      Value isModZero = builder.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, mod,
+          builder.create<arith::ConstantIndexOp>(loc, 0));
+      builder.create<cf::AssertOp>(
+          loc, isModZero,
+          "static result dims in reassoc group do not divide src dim evenly");
+    }
+  }
+};
+} // namespace
+} // namespace memref
+} // namespace mlir
+
+void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
+
+    // Load additional dialects of which ops may get created.
+    ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
+  });
+}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -10,6 +10,7 @@
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
   ParallelCombiningOpInterface.cpp
+  RuntimeVerifiableOpInterface.cpp
   ShapedOpInterfaces.cpp
   SideEffectInterfaces.cpp
   TilingInterface.cpp
@@ -44,6 +45,7 @@
 add_mlir_interface_library(InferTypeOpInterface)
 add_mlir_interface_library(LoopLikeInterface)
 add_mlir_interface_library(ParallelCombiningOpInterface)
+add_mlir_interface_library(RuntimeVerifiableOpInterface)
 add_mlir_interface_library(ShapedOpInterfaces)
 add_mlir_interface_library(SideEffectInterfaces)
 add_mlir_interface_library(TilingInterface)
diff --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
@@ -0,0 +1,17 @@
+//===- RuntimeVerifiableOpInterface.cpp - Op Verification -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+namespace mlir {
+class Location;
+class OpBuilder;
+} // namespace mlir
+
+/// Include the definitions of the interface.
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc"
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   Canonicalizer.cpp
   ControlFlowSink.cpp
   CSE.cpp
+  GenerateRuntimeVerification.cpp
   Inliner.cpp
   LocationSnapshot.cpp
   LoopInvariantCodeMotion.cpp
@@ -26,6 +27,7 @@
   MLIRCopyOpInterface
   MLIRLoopLikeInterface
   MLIRPass
+  MLIRRuntimeVerifiableOpInterface
   MLIRSideEffectInterfaces
   MLIRSupport
   MLIRTransformUtils
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -0,0 +1,49 @@
+//===- GenerateRuntimeVerification.cpp - Runtime Verification -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_GENERATERUNTIMEVERIFICATION
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct GenerateRuntimeVerificationPass
+    : public impl::GenerateRuntimeVerificationBase<
+          GenerateRuntimeVerificationPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void GenerateRuntimeVerificationPass::runOnOperation() {
+  // Implementations of RuntimeVerifiableOpInterface may create ops that are
+  // themselves runtime-verifiable. Verification IR should not be verified
+  // again, so gather all runtime-verifiable ops before mutating the IR.
+  SmallVector<RuntimeVerifiableOpInterface> ops;
+  getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
+    ops.push_back(verifiableOp);
+  });
+
+  // A single builder suffices; only its insertion point changes per op.
+  OpBuilder builder(getOperation()->getContext());
+  for (RuntimeVerifiableOpInterface verifiableOp : ops) {
+    builder.setInsertionPoint(verifiableOp);
+    verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
+  }
+}
+
+std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
+  return std::make_unique<GenerateRuntimeVerificationPass>();
+}
diff --git a/mlir/test/Dialect/MemRef/runtime-verification.mlir b/mlir/test/Dialect/MemRef/runtime-verification.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/runtime-verification.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -generate-runtime-verification -cse | FileCheck %s
+
+// CHECK-LABEL: func @expand_shape(
+// CHECK-SAME: %[[m:.*]]: memref
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[dim:.*]] = memref.dim %[[m]], %[[c0]]
+// CHECK: %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]]
+// CHECK: %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]]
+// CHECK: cf.assert %[[cmpi]], "static result dims in reassoc group do not divide src dim evenly"
+func.func @expand_shape(%m: memref<?xf32>) -> memref<?x5xf32> {
+  %0 = memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>
+  return %0 : memref<?x5xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1037,6 +1037,13 @@
     deps = [":OpBaseTdFiles"],
 )
 
+td_library(
+    name = "RuntimeVerifiableOpInterfaceTdFiles",
+    srcs = ["include/mlir/Interfaces/RuntimeVerifiableOpInterface.td"],
+    includes = ["include"],
+    deps = [":OpBaseTdFiles"],
+)
+
 td_library(
     name = "SideEffectInterfacesTdFiles",
     srcs = [
@@ -2992,6 +2999,18 @@
     ],
 )
 
+cc_library(
+    name = "RuntimeVerifiableOpInterface",
+    srcs = ["lib/Interfaces/RuntimeVerifiableOpInterface.cpp"],
+    hdrs = ["include/mlir/Interfaces/RuntimeVerifiableOpInterface.h"],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":RuntimeVerifiableOpInterfaceIncGen",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "VectorInterfaces",
     srcs = ["lib/Interfaces/VectorInterfaces.cpp"],
@@ -5715,6 +5734,24 @@
     deps = [":ParallelCombiningOpInterfaceTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "RuntimeVerifiableOpInterfaceIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Interfaces/RuntimeVerifiableOpInterface.td",
+    deps = [":RuntimeVerifiableOpInterfaceTdFiles"],
+)
+
 gentbl_cc_library(
     name = "VectorInterfacesIncGen",
     strip_include_prefix = "include",
@@ -5818,6 +5855,7 @@
         ":LoopLikeInterface",
         ":Pass",
         ":Rewrite",
+        ":RuntimeVerifiableOpInterface",
         ":SideEffectInterfaces",
         ":Support",
         ":TransformUtils",
@@ -9783,6 +9821,7 @@
         ":ArithDialect",
         ":ArithTransforms",
         ":ArithUtils",
+        ":ControlFlowDialect",
         ":DialectUtils",
         ":FuncDialect",
         ":IR",
@@ -9791,6 +9830,7 @@
         ":MemRefDialect",
         ":MemRefPassIncGen",
         ":Pass",
+        ":RuntimeVerifiableOpInterface",
         ":TensorDialect",
         ":Transforms",
         ":VectorDialect",