diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -121,6 +121,29 @@
   let assemblyFormat = "$target attr-dict";
 }
 
+def MergeHandlesOp : TransformDialectOp<"merge_handles",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Merges handles into one pointing to the union of payload ops";
+  let description = [{
+    Creates a new Transform IR handle value that points to the same Payload IR
+    operations as the operand handles. The Payload IR operations are listed
+    in the same order as they are in the operand handles, grouped by operand
+    handle, i.e., all Payload IR operations associated with the first handle
+    come first, then all Payload IR operations associated with the second handle
+    and so on. If `deduplicate` is set, each Payload IR operation is added to
+    the final list only once, at the position of its first occurrence,
+    regardless of whether it comes from the same or from different handles.
+    Consumes the operands and produces a new handle.
+  }];
+
+  let arguments = (ins Variadic<PDL_Operation>:$handles,
+                       UnitAttr:$deduplicate);
+  let results = (outs PDL_Operation:$result);
+  let assemblyFormat = "($deduplicate^)? $handles attr-dict";
+  let hasFolder = 1;
+}
+
 def PDLMatchOp : TransformDialectOp<"pdl_match",
     [DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let summary = "Finds ops that match the named PDL pattern";
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -286,6 +286,54 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// MergeHandlesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::MergeHandlesOp::apply(transform::TransformResults &results,
+                                 transform::TransformState &state) {
+  SmallVector<Operation *> operations;
+  for (Value operand : getHandles())
+    llvm::append_range(operations, state.getPayloadOps(operand));
+  if (!getDeduplicate()) {
+    results.set(getResult().cast<OpResult>(), operations);
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // SetVector keeps only the first occurrence of each op and preserves
+  // insertion order, so the result stays grouped by operand handle.
+  SetVector<Operation *> uniqued(operations.begin(), operations.end());
+  results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MergeHandlesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  for (Value operand : getHandles()) {
+    effects.emplace_back(MemoryEffects::Read::get(), operand,
+                         transform::TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Free::get(), operand,
+                         transform::TransformMappingResource::get());
+  }
+  effects.emplace_back(MemoryEffects::Allocate::get(), getResult(),
+                       transform::TransformMappingResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), getResult(),
+                       transform::TransformMappingResource::get());
+
+  // There are no effects on the Payload IR as this is only a handle
+  // manipulation.
+}
+
+OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
+  if (getDeduplicate() || getHandles().size() != 1)
+    return {};
+
+  // If deduplication is not required and there is only one operand, it can be
+  // used directly instead of merging.
+  return getHandles().front();
+}
+
 //===----------------------------------------------------------------------===//
 // PDLMatchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -28,6 +28,21 @@
         ip=ip)
 
 
+class MergeHandlesOp:
+
+  def __init__(self,
+               handles: Sequence[Union[Operation, Value]],
+               *,
+               deduplicate: bool = False,
+               loc=None,
+               ip=None):
+    super().__init__(
+        pdl.OperationType.get(), [_get_op_result_or_value(h) for h in handles],
+        deduplicate=deduplicate,
+        loc=loc,
+        ip=ip)
+
+
 class PDLMatchOp:
 
   def __init__(self,
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -460,3 +460,42 @@
     %1:2 = transform.test_correct_number_of_multi_results %0
   }
 }
+
+// -----
+
+// Expecting to match all arith operations by merging the handles that matched
+// addi and subi separately.
+func.func @foo(%arg0: index) {
+  // expected-remark @below {{matched}}
+  %0 = arith.addi %arg0, %arg0 : index
+  // expected-remark @below {{matched}}
+  %1 = arith.subi %arg0, %arg0 : index
+  // expected-remark @below {{matched}}
+  %2 = arith.addi %0, %1 : index
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @addi : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "arith.addi"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  pdl.pattern @subi : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "arith.subi"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @addi in %arg1
+    %1 = pdl_match @subi in %arg1
+    %2 = merge_handles %0, %1
+    test_print_remark_at_operand %2, "matched"
+  }
+}
+
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -82,3 +82,15 @@
   # CHECK: transform.sequence
   # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
   # CHECK: = get_closest_isolated_parent %[[ARG1]]
+
+
+@run
+def testMergeHandlesOp():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    transform.MergeHandlesOp([sequence.bodyTarget])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testMergeHandlesOp
+  # CHECK: transform.sequence
+  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
+  # CHECK: = merge_handles %[[ARG1]]