diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -48,6 +48,19 @@
           "::mlir::transform::TransformResults &":$transformResults,
           "::mlir::transform::TransformState &":$state
     )>,
+    InterfaceMethod<
+      /*desc=*/[{
+        Indicates whether the op instance allows its handle operands to be
+        associated with the same payload operations.
+      }],
+      /*returnType=*/"bool",
+      /*name=*/"allowsRepeatedHandleOperands",
+      /*arguments=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return false;
+      }]
+    >,
   ];
 
   let extraSharedClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -210,7 +210,7 @@
 }
 
 def MergeHandlesOp : TransformDialectOp<"merge_handles",
-    [DeclareOpInterfaceMethods<TransformOpInterface>,
+    [DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      SameOperandsAndResultType]> {
   let summary = "Merges handles into one pointing to the union of payload ops";
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -189,7 +189,8 @@
   for (OpOperand &target : transform->getOpOperands()) {
     // If the operand uses an invalidated handle, report it.
     auto it = invalidatedHandles.find(target.get());
-    if (it != invalidatedHandles.end())
+    if (!transform.allowsRepeatedHandleOperands() &&
+        it != invalidatedHandles.end())
       return it->getSecond()(transform->getLoc()), failure();
 
     // Invalidate handles pointing to the operations nested in the operation
@@ -201,6 +202,7 @@
     if (llvm::any_of(effects, consumesTarget))
       recordHandleInvalidation(target);
   }
+
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -449,6 +449,11 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
+  // Handles may be the same if deduplicating is enabled.
+  return getDeduplicate();
+}
+
 void transform::MergeHandlesOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getHandles(), effects);
diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir
--- a/mlir/test/Dialect/Transform/expensive-checks.mlir
+++ b/mlir/test/Dialect/Transform/expensive-checks.mlir
@@ -99,3 +99,17 @@
     transform.test_consume_operand %1, %2
   }
 }
+
+// -----
+
+// Deduplication attribute allows "merge_handles" to take repeated operands.
+
+module {
+
+  transform.sequence failures(propagate) {
+  ^bb0(%0: !pdl.operation):
+    %1 = transform.test_copy_payload %0
+    %2 = transform.test_copy_payload %0
+    transform.merge_handles %1, %2 { deduplicate } : !pdl.operation
+  }
+}