diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -49,7 +49,8 @@ // constant operation is marked as 'NoSideEffect' as it is a pure operation // and may be removed if dead. def ConstantOp : Toy_Op<"constant", - [NoSideEffect, DeclareOpInterfaceMethods]> { + [ConstantLike, NoSideEffect, + DeclareOpInterfaceMethods]> { // Provide a summary and description for this operation. This can be used to // auto-generate documentation of the operations within our dialect. let summary = "constant"; @@ -295,7 +296,7 @@ let hasFolder = 1; } -def StructConstantOp : Toy_Op<"struct_constant", [NoSideEffect]> { +def StructConstantOp : Toy_Op<"struct_constant", [ConstantLike, NoSideEffect]> { let summary = "struct constant"; let description = [{ Constant operation turns a literal struct value into an SSA value. The data diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -67,7 +67,7 @@ // ----- -def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> { +def SPV_ConstantOp : SPV_Op<"constant", [ConstantLike, NoSideEffect]> { let summary = "The op that declares a SPIR-V normal constant"; let description = [{ diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -796,7 +796,7 @@ //===----------------------------------------------------------------------===// def ConstantOp : Std_Op<"constant", - [NoSideEffect, DeclareOpInterfaceMethods]> { + [ConstantLike, NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "constant"; let arguments = (ins AnyAttr:$value); diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -48,8 +48,13 @@ } }; -/// The matcher that matches a constant foldable operation that has no side -/// effect, no operands and produces a single result. +/// The matcher that matches operations that have the `ConstantLike` trait. +struct constant_op_matcher { + bool match(Operation *op) { return op->hasTrait(); } +}; + +/// The matcher that matches operations that have the `ConstantLike` trait, and +/// binds the folded attribute value. template struct constant_op_binder { AttrT *bind_value; @@ -60,20 +65,19 @@ constant_op_binder() : bind_value(nullptr) {} bool match(Operation *op) { - if (op->getNumOperands() > 0 || op->getNumResults() != 1) - return false; - if (!op->hasNoSideEffect()) + if (!op->hasTrait()) return false; + // Fold the constant to an attribute. SmallVector foldedOp; - if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) { - if (auto attr = foldedOp.front().dyn_cast()) { - if (auto attrT = attr.dyn_cast()) { - if (bind_value) - *bind_value = attrT; - return true; - } - } + LogicalResult result = op->fold(/*operands=*/llvm::None, foldedOp); + (void)result; + assert(succeeded(result) && "expected constant to be foldable"); + + if (auto attr = foldedOp.front().get().dyn_cast()) { + if (bind_value) + *bind_value = attr; + return true; } return false; } @@ -201,8 +205,8 @@ } // end namespace detail /// Matches a constant foldable operation. -inline detail::constant_op_binder m_Constant() { - return detail::constant_op_binder(); +inline detail::constant_op_matcher m_Constant() { + return detail::constant_op_matcher(); } /// Matches a value from a constant foldable operation and writes the value to diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1549,6 +1549,8 @@ def Broadcastable : NativeOpTrait<"ResultsBroadcastableShape">; // X op Y == Y op X def Commutative : NativeOpTrait<"IsCommutative">; +// Op behaves like a constant. +def ConstantLike : NativeOpTrait<"ConstantLike">; // Op behaves like a function. def FunctionLike : NativeOpTrait<"FunctionLike">; // Op is isolated from above. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -902,6 +902,25 @@ } }; +/// This class provides the API for a sub-set of ops that are known to be +/// constant-like. These are non-side effecting operations with one result and +/// zero operands that can always be folded to a specific attribute value. +template +class ConstantLike : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + static_assert(ConcreteType::template hasTrait(), + "expected operation to produce one result"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to take zero operands"); + // TODO: We should verify that the operation can always be folded, but this + // requires that the attributes of the op already be verified. We should add + // support for verifying traits "after" the operation to enable this use + // case. + return success(); + } +}; + /// This class provides the API for ops that are known to be isolated from /// above. template diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -399,7 +399,7 @@ cst->erase(); return cleanupFailure(); } - assert(matchPattern(constOp, m_Constant(&attr))); + assert(matchPattern(constOp, m_Constant())); generatedConstants.push_back(constOp); results.push_back(constOp->getResult(0)); diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -57,7 +57,7 @@ // Ask the dialect to materialize a constant operation for this value. if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { assert(insertPt == builder.getInsertionPoint()); - assert(matchPattern(constOp, m_Constant(&value))); + assert(matchPattern(constOp, m_Constant())); return constOp; } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -454,7 +454,7 @@ // ----- func @generic_fun_result_0_element_type(%arg0: memref) { - // expected-error @+1 {{'linalg.dot' op expected 3 or more operands}} + // expected-error @+1 {{'linalg.dot' op expected 3 operands, but found 2}} linalg.dot(%arg0, %arg0): memref, memref } diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -24,7 +24,7 @@ // ----- func @failedSameOperandAndResultElementType_no_operands() { - // expected-error@+1 {{expected 1 or more operands}} + // expected-error@+1 {{expected 2 operands, but found 0}} "test.same_operand_element_type"() : () -> tensor<1xf32> } diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -55,7 +55,7 @@ // CHECK: ArrayRef tblgen_operands; // CHECK: }; -// CHECK: class AOp : public Op::Impl, OpTrait::ZeroSuccessor, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl +// CHECK: class AOp : public Op::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::HasNoSideEffect // CHECK: public: // CHECK: using Op::Op; // CHECK: using OperandAdaptor = AOpOperandAdaptor; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1523,14 +1523,6 @@ unsigned numVariadicSuccessors = op.getNumVariadicSuccessors(); addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors); - // Add the native and interface traits. - for (const auto &trait : op.getTraits()) { - if (auto opTrait = dyn_cast(&trait)) - opClass.addTrait(opTrait->getTrait()); - else if (auto opTrait = dyn_cast(&trait)) - opClass.addTrait(opTrait->getTrait()); - } - // Add variadic size trait and normal op traits. int numOperands = op.getNumOperands(); int numVariadicOperands = op.getNumVariadicOperands(); @@ -1555,6 +1547,14 @@ break; } } + + // Add the native and interface traits. + for (const auto &trait : op.getTraits()) { + if (auto opTrait = dyn_cast(&trait)) + opClass.addTrait(opTrait->getTrait()); + else if (auto opTrait = dyn_cast(&trait)) + opClass.addTrait(opTrait->getTrait()); + } } void OpEmitter::genOpNameGetter() {