diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -169,6 +169,27 @@
   let assemblyFormat = "$target attr-dict";
 }
 
+def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
+     [DeclareOpInterfaceMethods<TransformOpInterface>,
+      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+  let summary = "Get handle to the producer of this operation's operand number";
+  let description = [{
+    The handle defined by this Transform op corresponds to the operation that
+    produces the SSA value defined by the `target` and `operand_number`
+    arguments. If the origin of the SSA value is not an operation (i.e. it is
+    a block argument), or if `operand_number` does not refer to an operand of
+    the target, the transform produces a silenceable failure and the result
+    handle is associated with an empty list of operations.
+    On success, the result handle points to the producers of the requested
+    operand, one per target operation.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       I64Attr:$operand_number);
+  let results = (outs PDL_Operation:$parent);
+  let assemblyFormat = "$target `[` $operand_number `]` attr-dict";
+}
+
 def MergeHandlesOp : TransformDialectOp<"merge_handles",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -386,6 +386,42 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// GetProducerOfOperand
+//===----------------------------------------------------------------------===//
+
+/// Maps each payload op associated with `target` to the op defining its
+/// `operand_number`-th operand. Emits a silenceable error and associates the
+/// result with an empty list as soon as one payload op has no such producer.
+DiagnosedSilenceableFailure
+transform::GetProducerOfOperand::apply(transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  int64_t operandNumber = getOperandNumber();
+  SmallVector<Operation *> producers;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    // A negative index must be rejected explicitly: it compares below the
+    // operand count and would otherwise be used to index the operand list.
+    // getDefiningOp() returns null for block arguments.
+    Operation *producer =
+        (operandNumber < 0 || target->getNumOperands() <= operandNumber)
+            ? nullptr
+            : target->getOperand(operandNumber).getDefiningOp();
+    if (!producer) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError()
+          << "could not find a producer for operand number: " << operandNumber
+          << " of " << *target;
+      diag.attachNote(target->getLoc()) << "target op";
+      results.set(getResult().cast<OpResult>(),
+                  SmallVector<mlir::Operation *>{});
+      return diag;
+    }
+    producers.push_back(producer);
+  }
+  results.set(getResult().cast<OpResult>(), producers);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // MergeHandlesOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -727,3 +727,49 @@
     transform.test_print_remark_at_operand %results, "transform applied"
   }
 }
+
+// -----
+
+func.func @get_producer_of_operand(%arg0: index, %arg1: index) {
+  // expected-remark @below {{found muli}}
+  %0 = arith.muli %arg0, %arg1 : index
+  arith.addi %0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %addi = transform.structured.match ops{["arith.addi"]} in %arg1
+  %muli = get_producer_of_operand %addi[0]
+  transform.test_print_remark_at_operand %muli, "found muli"
+}
+
+// -----
+
+func.func @get_producer_of_block_argument(%arg0: index, %arg1: index) {
+  // expected-note @below {{target op}}
+  %0 = arith.muli %arg0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %arg1
+  // expected-error @below {{could not find a producer for operand number: 0 of}}
+  %bbarg = get_producer_of_operand %muli[0]
+}
+
+// -----
+
+func.func @get_producer_of_negative_operand_number(%arg0: index, %arg1: index) {
+  // expected-note @below {{target op}}
+  %0 = arith.muli %arg0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %arg1
+  // expected-error @below {{could not find a producer for operand number: -1 of}}
+  %none = get_producer_of_operand %muli[-1]
+}