diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1931,6 +1931,10 @@
 // Op attached regions have no arguments
 def NoRegionArguments : NativeOpTrait<"NoRegionArguments">, StructuralOpTrait;
 
+// Indicates that an op's own verifier should run before the verifiers of its
+// traits, so that trait and interface verifiers may rely on its checks.
+def VerifyOpBeforeTraits : NativeOpTrait<"VerifyOpBeforeTraits">;
+
 //===----------------------------------------------------------------------===//
 // OpInterface definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1455,6 +1455,12 @@
 /// behavior to vectors/tensors, and systematize conversion between these forms.
 bool hasElementwiseMappableTraits(Operation *op);
 
+/// The presence of this trait indicates that an op's own verifier (`verify`
+/// and `verifyRegions`) runs before the verifiers of its traits.
+template <typename ConcreteType>
+struct VerifyOpBeforeTraits
+    : public TraitBase<ConcreteType, VerifyOpBeforeTraits> {};
+
 } // namespace OpTrait
 
 //===----------------------------------------------------------------------===//
@@ -1876,6 +1882,13 @@
   static LogicalResult verifyInvariants(Operation *op) {
     static_assert(hasNoDataMembers(),
                   "Op class shouldn't define new data members");
+    // Ops tagged with VerifyOpBeforeTraits run their own verifier first so
+    // that trait and interface verifiers may rely on the invariants it checks.
+    if (ConcreteType::template hasTrait<OpTrait::VerifyOpBeforeTraits>())
+      return failure(
+          failed(cast<ConcreteType>(op).verify()) ||
+          failed(
+              op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)));
     return failure(
         failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) ||
         failed(cast<ConcreteType>(op).verify()));
@@ -1887,6 +1900,13 @@
   static LogicalResult verifyRegionInvariants(Operation *op) {
     static_assert(hasNoDataMembers(),
                   "Op class shouldn't define new data members");
+    // Same ordering as verifyInvariants, applied to the region verifiers.
+    if (ConcreteType::template hasTrait<OpTrait::VerifyOpBeforeTraits>())
+      return failure(
+          failed(cast<ConcreteType>(op).verifyRegions()) ||
+          failed(
+              op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>(
+                  op)));
     return failure(
         failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>(
             op)) ||
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -45,6 +45,10 @@
 namespace mlir {
 class DLTIDialect;
 class RewritePatternSet;
+namespace detail {
+/// Verifier for TestPositiveValueOpInterface (defined in TestDialect.cpp).
+LogicalResult verifyTestPositiveValueOpInterface(Operation *op);
+} // namespace detail
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1588,6 +1588,31 @@
   setResultRanges(getResult(), range);
 }
 
+//===----------------------------------------------------------------------===//
+// PositiveIntAttrWithInterfaceOp
+//===----------------------------------------------------------------------===//
+
+int64_t PositiveIntAttrWithInterfaceOp::getIntValue() {
+  return getIntAttrAttr().cast<IntegerAttr>().getInt();
+}
+
+LogicalResult PositiveIntAttrWithInterfaceOp::verify() {
+  if (!getIntAttrAttr().isa<IntegerAttr>())
+    return emitOpError("requires 'int_attr' to be an integer attribute");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestPositiveValueOpInterface
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::detail::verifyTestPositiveValueOpInterface(Operation *op) {
+  int64_t value = cast<TestPositiveValueOpInterface>(op).getIntValue();
+  if (value <= 0)
+    return op->emitOpError("requires a positive int value, but got ") << value;
+  return success();
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestTypeInterfaces.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.td
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -120,4 +120,13 @@
 
 def TestConcreteEffect : TestEffect<"TestEffects::Concrete">;
 
+def TestPositiveValueOpInterface : OpInterface<"TestPositiveValueOpInterface"> {
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<"Returns the int value associated with this op.",
+      "int64_t", "getIntValue", (ins)>,
+  ];
+  let verify = [{ return detail::verifyTestPositiveValueOpInterface($_op); }];
+}
+
 #endif // MLIR_TEST_DIALECT_TEST_INTERFACES
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -202,6 +202,17 @@
   );
 }
 
+def PositiveIntAttrWithInterfaceOp : TEST_Op<"positive_int_attr_with_interface",
+    [DeclareOpInterfaceMethods<TestPositiveValueOpInterface>,
+     VerifyOpBeforeTraits]> {
+  // This could be an I64Attr, but the attribute kind is checked by the
+  // hand-written verifier instead in order to exercise VerifyOpBeforeTraits:
+  // that verifier must run before the TestPositiveValueOpInterface verifier,
+  // whose getIntValue() call casts the attribute to IntegerAttr.
+  let arguments = (ins AnyAttr:$int_attr);
+  let hasVerifier = 1;
+}
+
 def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> {
   let arguments = (ins SomeI32Enum:$attr);
   let results = (outs I32:$val);
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/Verifier.h"
 #include "gtest/gtest.h"
 
 #include "../../test/lib/Dialect/Test/TestAttributes.h"
@@ -75,3 +76,25 @@
   EXPECT_TRUE(typeSet.contains(type2));
   EXPECT_FALSE(typeSet.contains(type3));
 }
+
+TEST(InterfaceTest, VerifyOpFirst) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+
+  // Build an invalid op: `int_attr` is a string, not an integer.
+  OpBuilder builder(&context);
+  Location loc = builder.getUnknownLoc();
+  SmallVector<NamedAttribute> attrs;
+  attrs.push_back(NamedAttribute(builder.getStringAttr("int_attr"),
+                                 builder.getStringAttr("foo")));
+  Operation *op = builder.create(
+      loc,
+      builder.getStringAttr(PositiveIntAttrWithInterfaceOp::getOperationName()),
+      /*operands=*/{}, /*types=*/{}, attrs);
+  // The verifier fails but it does not crash. (It would crash if we were to
+  // verify the interface first.)
+  EXPECT_TRUE(failed(verify(op)));
+
+  // The op was created without an insertion point, so nothing else owns it.
+  op->erase();
+}