diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -189,6 +189,27 @@
     "$target attr-dict `:` functional-type(operands, results)";
 }
 
+def GetConsumersOfResult : TransformDialectOp<"get_consumers_of_result",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+  let summary = "Get handle to the consumers of this operation's result number";
+  let description = [{
+    The handle defined by this Transform op corresponds to all operations that
+    consume the SSA value defined by the `target` and `result_number`
+    arguments.
+    This operation applies to a single payload operation, otherwise it
+    definitely fails.
+    The return handle points to the consuming operations, which can
+    be empty.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$result_number);
+  let results = (outs TransformHandleTypeInterface:$consumers);
+  let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
+                       "functional-type(operands, results)";
+}
+
 def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -399,6 +399,37 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// GetConsumersOfResult
+//===----------------------------------------------------------------------===//
+
+/// Maps the result handle to the users of result `result_number` of the single
+/// payload op associated with `target`.
+DiagnosedSilenceableFailure
+transform::GetConsumersOfResult::apply(transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  int64_t resultNumber = getResultNumber();
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  // An empty target handle produces an empty result handle, not a failure.
+  if (payloadOps.empty()) {
+    results.set(getResult().cast<OpResult>(), {});
+    return DiagnosedSilenceableFailure::success();
+  }
+  if (payloadOps.size() != 1)
+    return emitDefiniteFailure()
+           << "handle must be mapped to exactly one payload op";
+
+  Operation *target = payloadOps.front();
+  // `result_number` is a signed 64-bit attribute, so guard the lower bound as
+  // well: a negative value would pass the upper-bound check and then be used
+  // as an out-of-range result index.
+  if (resultNumber < 0 || target->getNumResults() <= resultNumber)
+    return emitDefiniteFailure() << "result number overflow";
+  results.set(getResult().cast<OpResult>(),
+              llvm::to_vector(target->getResult(resultNumber).getUsers()));
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // GetProducerOfOperand
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -775,6 +775,53 @@
 
 // -----
 
+func.func @get_consumer(%arg0: index, %arg1: index) {
+  %0 = arith.muli %arg0, %arg1 : index
+  // expected-remark @below {{found addi}}
+  arith.addi %0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %arg1
+  %addi = get_consumers_of_result %muli[0] : (!pdl.operation) -> !pdl.operation
+  transform.test_print_remark_at_operand %addi, "found addi" : !pdl.operation
+}
+
+// -----
+
+func.func @get_consumer_fail_1(%arg0: index, %arg1: index) {
+  %0 = arith.muli %arg0, %arg1 : index
+  %1 = arith.muli %arg0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %arg1
+  // expected-error @below {{handle must be mapped to exactly one payload op}}
+  %bbarg = get_consumers_of_result %muli[0] : (!pdl.operation) -> !pdl.operation
+
+}
+
+// -----
+
+func.func @get_consumer_fail_2(%arg0: index, %arg1: index) {
+  %0 = arith.muli %arg0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %arg1
+  // expected-error @below {{result number overflow}}
+  %bbarg = get_consumers_of_result %muli[1] : (!pdl.operation) -> !pdl.operation
+
+}
+
+// -----
+
 func.func @split_handles(%a: index, %b: index, %c: index) {
   %0 = arith.muli %a, %b : index
   %1 = arith.muli %a, %c : index