diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -400,6 +400,10 @@ [DeclareOpInterfaceMethods]> { ... } ``` +A verification method can also be specified on the `OpInterface` by setting +`verify`; within its body, `$_op` refers to the operation being verified. The +generated trait then gets a `verifyTrait` method run on every implementing op. + ### Builder methods For each operation, there are a few builders automatically generated based on diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.h b/mlir/include/mlir/Analysis/InferTypeOpInterface.h --- a/mlir/include/mlir/Analysis/InferTypeOpInterface.h +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.h @@ -72,8 +72,6 @@ Attribute attr; }; -#include "mlir/Analysis/InferTypeOpInterface.h.inc" - namespace detail { // Helper function to infer return tensor returns types given element and shape // inference function. @@ -89,8 +87,14 @@ MLIRContext *context, Optional<Location> location, ValueRange operands, ArrayRef<NamedAttribute> attributes, RegionRange regions, SmallVectorImpl<Type> &inferedReturnTypes); + +/// Verifies that the result types inferred for the op are compatible with its +/// actual result types. Precondition: op implements InferTypeOpInterface.
+LogicalResult verifyInferredResultTypes(Operation *op); } // namespace detail +#include "mlir/Analysis/InferTypeOpInterface.h.inc" + namespace OpTrait { /// Tensor type inference trait that constructs a tensor from the infered diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/mlir/include/mlir/Analysis/InferTypeOpInterface.td --- a/mlir/include/mlir/Analysis/InferTypeOpInterface.td +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.td @@ -60,6 +60,10 @@ }] >, ]; + + let verify = [{ + return detail::verifyInferredResultTypes($_op); + }]; } def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> { diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1411,8 +1411,12 @@ // OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in // C++. The purpose to wrap around C++ symbol string with this class is to make // interfaces specified for ops in TableGen less alien and more integrated. -class OpInterfaceTrait<string name> : NativeOpTrait<""> { +class OpInterfaceTrait<string name, code verifyBody = [{}]> : NativeOpTrait<""> { let trait = name # "::Trait"; + + // Specify the body of the verification function. `$_op` will be replaced with + // the op being verified; an empty body (the default) emits no verifyTrait. + code verify = verifyBody; } // This class represents a single, optionally static, interface method. diff --git a/mlir/include/mlir/TableGen/OpInterfaces.h b/mlir/include/mlir/TableGen/OpInterfaces.h --- a/mlir/include/mlir/TableGen/OpInterfaces.h +++ b/mlir/include/mlir/TableGen/OpInterfaces.h @@ -86,6 +86,9 @@ // Return the description of this method if it has one. llvm::Optional<StringRef> getDescription() const; + // Return the verify method body if it has one. + llvm::Optional<StringRef> getVerify() const; + private: // The TableGen definition of this interface.
const llvm::Record *def; diff --git a/mlir/lib/Analysis/InferTypeOpInterface.cpp b/mlir/lib/Analysis/InferTypeOpInterface.cpp --- a/mlir/lib/Analysis/InferTypeOpInterface.cpp +++ b/mlir/lib/Analysis/InferTypeOpInterface.cpp @@ -45,3 +45,17 @@ } return success(); } + +LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { + SmallVector<Type, 4> inferedReturnTypes; + auto retTypeFn = cast<InferTypeOpInterface>(op); + if (failed(retTypeFn.inferReturnTypes(op->getContext(), op->getLoc(), + op->getOperands(), op->getAttrs(), + op->getRegions(), inferedReturnTypes))) + return failure(); + SmallVector<Type, 4> resultTypes(op->getResultTypes()); + if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) + return op->emitOpError( + "inferred type incompatible with return type of operation"); + return success(); +} diff --git a/mlir/lib/TableGen/OpInterfaces.cpp b/mlir/lib/TableGen/OpInterfaces.cpp --- a/mlir/lib/TableGen/OpInterfaces.cpp +++ b/mlir/lib/TableGen/OpInterfaces.cpp @@ -85,3 +85,9 @@ auto value = def->getValueAsString("description"); return value.empty() ? llvm::Optional<StringRef>() : value; } + +// Return the verify method body of this interface if it has one. +llvm::Optional<StringRef> OpInterface::getVerify() const { + auto value = def->getValueAsString("verify"); + return value.empty() ? llvm::Optional<StringRef>() : value; +} diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -103,26 +103,6 @@ }; return; } - - // Verification check. - // TODO: Move to ops that implement type infer interface.
- getFunction().walk([this](Operation *op) -> void { - auto retTypeFn = dyn_cast<InferTypeOpInterface>(op); - if (!retTypeFn) - return; - auto *context = &getContext(); - SmallVector<Type, 4> inferedReturnTypes; - if (failed(retTypeFn.inferReturnTypes( - context, op->getLoc(), op->getOperands(), op->getAttrs(), - op->getRegions(), inferedReturnTypes))) - return; - SmallVector<Type, 4> resultTypes(op->getResultTypes()); - if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) { - op->emitOpError( - "inferred type incompatible with return type of operation"); - return; - } - }); } }; } // end anonymous namespace diff --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir --- a/mlir/test/mlir-tblgen/return-types.mlir +++ b/mlir/test/mlir-tblgen/return-types.mlir @@ -23,7 +23,6 @@ // ----- -// CHECK-LABEL: testReturnTypeOpInterface func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) { // expected-error@+1 {{incompatible with return type}} %bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32> @@ -32,7 +31,6 @@ // ----- -// CHECK-LABEL: testReturnTypeOpInterface func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) { // expected-error@+1 {{operand type mismatch}} %bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32> diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -12,6 +12,7 @@ #include "DocGenUtilities.h" #include "mlir/Support/STLExtras.h" +#include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/OpInterfaces.h" #include "llvm/ADT/SmallVector.h" @@ -152,6 +153,12 @@ // Insert the default implementation for any methods. for (auto &method : interface.getMethods()) { + // Flag interface methods named verifyTrait.
+ if (method.getName() == "verifyTrait") + PrintFatalError( + formatv("'verifyTrait' method cannot be specified as interface " + "method for '{0}'; set 'verify' on OpInterfaceTrait instead", + interfaceName)); auto defaultImpl = method.getDefaultImplementation(); if (!defaultImpl) continue; @@ -162,6 +169,13 @@ os << " {\n" << defaultImpl.getValue() << " }\n"; } + if (auto verify = interface.getVerify()) { + tblgen::FmtContext traitCtx; + traitCtx.withOp("op"); + os << " static LogicalResult verifyTrait(Operation *op) {\n" + << tblgen::tgfmt(*verify, &traitCtx) << "\n }\n"; + } + os << " };\n"; }