diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -230,6 +230,27 @@
                        "functional-type(operands, results)";
 }
 
+def GetResultOp : TransformDialectOp<"get_result",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+  let summary = "Get handle to a result of the targeted op";
+  let description = [{
+    The handle defined by this Transform op corresponds to the OpResult at
+    position `result_number` of each payload operation associated with the
+    `target` handle.
+
+    This transform produces a silenceable failure if `result_number` is
+    negative or if any of the targeted operations does not have enough
+    results.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$result_number);
+  let results = (outs TransformValueHandleTypeInterface:$result);
+  let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
+                       "functional-type(operands, results)";
+}
+
 def MergeHandlesOp : TransformDialectOp<"merge_handles",
     [DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -722,6 +722,8 @@
     if (opResult.getType().isa<TransformParamTypeInterface>())
       results.setParams(opResult, {});
+    else if (opResult.getType().isa<TransformValueHandleTypeInterface>())
+      results.setValues(opResult, {});
     else
       results.set(opResult, {});
   }
@@ -831,7 +833,7 @@
 void transform::TransformResults::setValues(OpResult handle,
                                             ValueRange values) {
   int64_t position = handle.getResultNumber();
-  assert(position < static_cast<int64_t>(values.size()) &&
+  assert(position < static_cast<int64_t>(this->values.size()) &&
          "setting values for a non-existent handle");
   assert(this->values[position].data() == nullptr && "values already set");
   assert(operations[position].data() == nullptr &&
@@ -861,8 +863,8 @@
 
 ArrayRef<Value>
 transform::TransformResults::getValues(unsigned resultNumber) const {
-  assert(resultNumber < params.size() &&
-         "querying params for a non-existent handle");
+  assert(resultNumber < values.size() &&
+         "querying values for a non-existent handle");
   assert(values[resultNumber].data() != nullptr &&
          "querying unset values (ops or params expected?)");
   return values[resultNumber];
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -474,6 +474,36 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// GetResultOp
+//===----------------------------------------------------------------------===//
+
+/// Collects the `result_number`-th OpResult of every payload op associated
+/// with `target` and binds them to the value handle produced by this op.
+DiagnosedSilenceableFailure
+transform::GetResultOp::apply(transform::TransformResults &results,
+                              transform::TransformState &state) {
+  // `result_number` is an unconstrained I64Attr, so it may be negative. A
+  // negative index would pass the `>= getNumResults()` check below and reach
+  // getOpResult, so reject it up front.
+  int64_t resultNumber = getResultNumber();
+  if (resultNumber < 0)
+    return emitSilenceableError() << "expected a non-negative result number";
+
+  SmallVector<Value> opResults;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    if (resultNumber >= target->getNumResults()) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "targeted op does not have enough results";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+    opResults.push_back(target->getOpResult(resultNumber));
+  }
+  results.setValues(getResult().cast<OpResult>(), opResults);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // MergeHandlesOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1190,3 +1190,50 @@
   // expected-error @below {{unexpectedly consumed a value that is not a handle as operand #0}}
   test_consume_operand %0 : !transform.test_dialect_param
 }
+
+// -----
+
+func.func @get_result_of_op(%arg0: index, %arg1: index) -> index {
+  // expected-remark @below {{addi result}}
+  // expected-note @below {{value handle points to an op result #0}}
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %result = transform.get_result %addi[0] : (!pdl.operation) -> !transform.any_value
+  transform.test_print_remark_at_operand_value %result, "addi result" : !transform.any_value
+}
+
+// -----
+
+func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
+  // expected-note @below {{target op}}
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  // expected-error @below {{targeted op does not have enough results}}
+  %result = transform.get_result %addi[1] : (!pdl.operation) -> !transform.any_value
+  transform.test_print_remark_at_operand_value %result, "addi result" : !transform.any_value
+}
+
+// -----
+
+func.func @get_negative_result_of_op(%arg0: index, %arg1: index) -> index {
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  // expected-error @below {{expected a non-negative result number}}
+  %result = transform.get_result %addi[-1] : (!pdl.operation) -> !transform.any_value
+  transform.test_print_remark_at_operand_value %result, "addi result" : !transform.any_value
+}