diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -210,6 +210,37 @@ let hasFolder = 1; } +def SplitHandlesOp : TransformDialectOp<"split_handles", + [FunctionalStyleTransformOpTrait, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + DeclareOpInterfaceMethods<TransformOpInterface>]> { + let summary = "Splits handles from a union of payload ops to a list"; + let description = [{ + Creates `num_result_handles` transform IR handles extracted from the + `handle` operand. The resulting Payload IR operation handles are listed + in the same order as the operations appear in the source `handle`. + This is useful for ensuring a statically known number of operations are + tracked by the source `handle` and to extract them into individual handles + that can be further manipulated in isolation. + + This operation succeeds and returns `num_result_handles` handles if the + statically specified `num_result_handles` corresponds to the dynamic + number of operations contained in the source `handle`. If the source + `handle` is empty, the operation also succeeds and all result handles are + empty. Otherwise, all result handles are empty and a silenceable error is + reported. + }]; + + let arguments = (ins PDL_Operation:$handle, + I64Attr:$num_result_handles); + let results = (outs Variadic<PDL_Operation>:$results); + let assemblyFormat = [{ + $handle `in` `[` $num_result_handles `]` + custom<StaticNumPDLResults>(type($results), ref($num_result_handles)) + attr-dict + }]; +} + def PDLMatchOp : TransformDialectOp<"pdl_match", [DeclareOpInterfaceMethods<TransformOpInterface>]> { let summary = "Finds ops that match the named PDL pattern"; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -23,6 +23,7 @@ using namespace mlir; +/// Custom parser for ReplicateOp.
static ParseResult parsePDLOpTypedResults( OpAsmParser &parser, SmallVectorImpl<Type> &types, const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) { @@ -30,9 +31,29 @@ return success(); } +/// Custom printer for ReplicateOp. static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange, ValueRange) {} +/// Custom parser for SplitHandlesOp. All `num_result_handles` result types +/// are !pdl.operation, so they are inferred rather than parsed. +static ParseResult parseStaticNumPDLResults(OpAsmParser &parser, + SmallVectorImpl<Type> &types, + IntegerAttr numHandlesAttr) { + int64_t numHandles = numHandlesAttr.getInt(); + // A negative count would wrap around to a huge size in resize(). + if (numHandles < 0) + return parser.emitError(parser.getNameLoc()) + << "expected a non-negative number of result handles"; + types.resize(numHandles, pdl::OperationType::get(parser.getContext())); + return success(); +} + +/// Custom printer for SplitHandlesOp. Prints nothing: the result types are +/// implied by `num_result_handles`. +static void printStaticNumPDLResults(OpAsmPrinter &, Operation *, TypeRange, + IntegerAttr) {} + #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" @@ -452,6 +473,44 @@ return getHandles().front(); } +//===----------------------------------------------------------------------===// +// SplitHandlesOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::SplitHandlesOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + ArrayRef<Operation *> payloadOps = state.getPayloadOps(getHandle()); + int64_t numPayloadOps = payloadOps.size(); + int64_t expectedNumResultHandles = getNumResultHandles();
+ if (numPayloadOps != expectedNumResultHandles) { + // Set every result to an empty payload list since the split fails. + for (Value result : getResults()) + results.set(result.cast<OpResult>(), {}); + // Propagate mode: an empty source handle yields empty result handles and + // is not an error. + if (numPayloadOps == 0) + return DiagnosedSilenceableFailure::success(); + // Use "only" solely when the handle holds fewer ops than expected. + return emitSilenceableError() + << getHandle() << " expected to contain " << expectedNumResultHandles + << " operation handles but it " + << (numPayloadOps < expectedNumResultHandles ? "only " : "") + << "contains " << numPayloadOps << " handles"; + } + for (auto en : llvm::enumerate(payloadOps)) + results.set(getResults()[en.index()].cast<OpResult>(), en.value()); + return DiagnosedSilenceableFailure::success(); +} + +void transform::SplitHandlesOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getHandle(), effects); + producesHandle(getResults(), effects); + // There are no effects on the Payload IR as this is only a handle + // manipulation. +} + //===----------------------------------------------------------------------===// // PDLMatchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -761,3 +761,24 @@ } +// ----- + +func.func @split_handles(%a: index, %b: index, %c: index) { + %0 = arith.muli %a, %b : index + %1 = arith.muli %a, %c : index + return +} + +transform.sequence failures(suppress) { +^bb1(%fun: !pdl.operation): + %muli = transform.structured.match ops{["arith.muli"]} in %fun + %h:2 = split_handles %muli in [2] + // expected-remark @below {{1}} + transform.test_print_number_of_associated_payload_ir_ops %h#0 + // Both forms succeed; the second form has all empty handles.
+ %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun + // expected-error @below {{expected to contain 3 operation handles but it only contains 2 handles}} + %h_2:3 = split_handles %muli_2 in [3] + // expected-remark @below {{0}} + transform.test_print_number_of_associated_payload_ir_ops %h_2#0 +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -292,6 +292,12 @@ DiagnosedSilenceableFailure mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply( transform::TransformResults &results, transform::TransformState &state) { + if (!getHandle()) { + // Return early: falling through would emit a second remark and query + // getPayloadOps() with a null value. + emitRemark() << 0; + return DiagnosedSilenceableFailure::success(); + } emitRemark() << state.getPayloadOps(getHandle()).size(); return DiagnosedSilenceableFailure::success(); }