diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -210,6 +210,38 @@ let hasFolder = 1; } +def SplitHandlesOp : TransformDialectOp<"split_handles", + [FunctionalStyleTransformOpTrait, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + DeclareOpInterfaceMethods<TransformOpInterface>]> { + let summary = "Splits handles from a union of payload ops to a list"; + let description = [{ + Creates `num_result_handles` transform IR handles extracted from the + `handle` operand. The resulting Payload IR operation handles are listed + in the same order as the operations appear in the source `handle`. + This is useful for ensuring a statically known number of operations are + tracked by the source `handle` and to extract them into individual handles + that can be further manipulated in isolation. + + This operation succeeds and produces `num_result_handles` handles if the + statically specified `num_result_handles` corresponds to the dynamic number + of operations contained in the source `handle`. It also succeeds, with all + result handles empty, when the source `handle` is empty. Otherwise it + produces a silenceable failure and all result handles are empty, so that + later transforms can still proceed on a best-effort basis in + `failures(suppress)` mode.
+ }]; + + let arguments = (ins PDL_Operation:$handle, + I64Attr:$num_result_handles); + let results = (outs Variadic<PDL_Operation>:$results); + let assemblyFormat = [{ + $handle `in` `[` $num_result_handles `]` + custom<StaticNumPDLResults>(type($results), ref($num_result_handles)) + attr-dict + }]; +} + def PDLMatchOp : TransformDialectOp<"pdl_match", [DeclareOpInterfaceMethods]> { let summary = "Finds ops that match the named PDL pattern"; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -590,9 +590,11 @@ getOps()->getAsValueRange().end()); ArrayRef payloadOps = state.getPayloadOps(getTarget()); - if (payloadOps.size() != 1) + if (payloadOps.size() != 1) { + results.set(getResult().cast<OpResult>(), {}); return DiagnosedSilenceableFailure( this->emitOpError("requires exactly one target handle")); + } SmallVector res; auto matchFun = [&](Operation *op) { @@ -877,8 +879,11 @@ } return OpFoldResult(op->getResult(0)); })); - if (!diag.succeeded()) + if (!diag.succeeded()) { + results.set(getFirst().cast<OpResult>(), {}); + results.set(getSecond().cast<OpResult>(), {}); return diag; + } if (splitPoints.size() != payload.size()) { emitError() << "expected the dynamic split point handle to point to as " @@ -900,6 +905,8 @@ if (!linalgOp) { auto diag = emitSilenceableError() << "only applies to structured ops"; diag.attachNote(target->getLoc()) << "target op"; + results.set(getFirst().cast<OpResult>(), {}); + results.set(getSecond().cast<OpResult>(), {}); return diag; } @@ -907,6 +914,8 @@ auto diag = emitSilenceableError() << "dimension " << getDimension() << " does not exist in target op"; diag.attachNote(target->getLoc()) << "target op"; + results.set(getFirst().cast<OpResult>(), {}); + results.set(getSecond().cast<OpResult>(), {}); return diag; } diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -47,6 +47,7 @@ << scf::ForOp::getOperationName() << "' parent"; diag.attachNote(target->getLoc()) << "target op"; + results.set(getResult().cast<OpResult>(), {}); return diag; } current = loop; @@ -100,6 +101,7 @@ DiagnosedSilenceableFailure diag = emitSilenceableError() << "failed to outline"; diag.attachNote(target->getLoc()) << "target op"; + results.set(getTransformed().cast<OpResult>(), {}); return diag; } func::CallOp call; diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -225,8 +225,11 @@ } transform::TransformResults results(transform->getNumResults()); + // Compute the result but do not short-circuit the silenceable failure case as + // we still want the handles to propagate properly so the "suppress" mode can + // proceed on a best effort basis. DiagnosedSilenceableFailure result(transform.apply(results, *this)); - if (!result.succeeded()) + if (result.isDefiniteFailure()) return result; // Remove the mapping for the operand if it is consumed by the operation. This @@ -258,7 +261,7 @@ DBGS() << "Top-level payload:\n"; getTopLevel()->print(llvm::dbgs()); }); - return DiagnosedSilenceableFailure::success(); + return result; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -23,6 +23,7 @@ using namespace mlir; +/// Custom parser for ReplicateOp.
static ParseResult parsePDLOpTypedResults( OpAsmParser &parser, SmallVectorImpl &types, const SmallVectorImpl &handles) { @@ -30,9 +31,27 @@ return success(); } +/// Custom printer for ReplicateOp. static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange, ValueRange) {} +/// Custom parser for SplitHandlesOp. +static ParseResult parseStaticNumPDLResults(OpAsmParser &parser, + SmallVectorImpl<Type> &types, + IntegerAttr numHandlesAttr) { + // Reject negative counts: `resize` would see them as a huge unsigned size. + int64_t numHandles = numHandlesAttr.getInt(); + if (numHandles < 0) + return parser.emitError(parser.getCurrentLocation()) + << "expected a non-negative number of result handles"; + types.resize(numHandles, pdl::OperationType::get(parser.getContext())); + return success(); +} + +/// Custom printer for SplitHandlesOp. +static void printStaticNumPDLResults(OpAsmPrinter &, Operation *, TypeRange, + IntegerAttr) {} + #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" @@ -452,6 +471,46 @@ return getHandles().front(); } +//===----------------------------------------------------------------------===// +// SplitHandlesOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::SplitHandlesOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + ArrayRef<Operation *> payloadOps = state.getPayloadOps(getHandle()); + int64_t numPayloadOps = payloadOps.size(); + int64_t expectedNumResultHandles = getNumResultHandles(); + if (numPayloadOps != expectedNumResultHandles) { + // Failing case needs to propagate gracefully for both suppress and + // propagate modes: every result handle maps to an empty payload list. + for (Value result : getResults()) + results.set(result.cast<OpResult>(), {}); + // Empty input handle corner case: always propagates empty handles in both + // suppress and propagate modes. + if (numPayloadOps == 0) + return DiagnosedSilenceableFailure::success(); + // If the input handle was not empty and the number of result handles does + // not match, this is a legit silenceable error.
+ return emitSilenceableError() + << getHandle() << " expected to contain " << expectedNumResultHandles + << " operation handles but it contains " << numPayloadOps + << " handles"; + } + // Normal successful case: the i-th payload op maps to the i-th result. + for (auto en : llvm::enumerate(payloadOps)) + results.set(getResults()[en.index()].cast<OpResult>(), en.value()); + return DiagnosedSilenceableFailure::success(); +} + +void transform::SplitHandlesOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getHandle(), effects); + producesHandle(getResults(), effects); + // There are no effects on the Payload IR as this is only a handle + // manipulation. +} + //===----------------------------------------------------------------------===// // PDLMatchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -761,3 +761,57 @@ } +// ----- + +func.func @split_handles(%a: index, %b: index, %c: index) { + %0 = arith.muli %a, %b : index + %1 = arith.muli %a, %c : index + return +} + +transform.sequence failures(propagate) { +^bb1(%fun: !pdl.operation): + %muli = transform.structured.match ops{["arith.muli"]} in %fun + %h:2 = split_handles %muli in [2] + // expected-remark @below {{1}} + transform.test_print_number_of_associated_payload_ir_ops %h#0 + %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun + // expected-error @below {{expected to contain 3 operation handles but it contains 2 handles}} + %h_2:3 = split_handles %muli_2 in [3] +} + +// ----- + +func.func @split_handles(%a: index, %b: index, %c: index) { + %0 = arith.muli %a, %b : index + %1 = arith.muli %a, %c : index + return +} + +transform.sequence failures(suppress) { +^bb1(%fun: !pdl.operation): + %muli = transform.structured.match ops{["arith.muli"]} in
%fun + %h:2 = split_handles %muli in [2] + // expected-remark @below {{1}} + transform.test_print_number_of_associated_payload_ir_ops %h#0 + %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun + // Silenceable failure and all handles are now empty. + %h_2:3 = split_handles %muli_2 in [3] + // expected-remark @below {{0}} + transform.test_print_number_of_associated_payload_ir_ops %h_2#0 +} + +// ----- + +func.func @split_handles(%a: index, %b: index, %c: index) { + %0 = arith.muli %a, %b : index + %1 = arith.muli %a, %c : index + return +} + +transform.sequence failures(propagate) { +^bb1(%fun: !pdl.operation): + %muli = transform.structured.match ops{["arith.muli"]} in %fun + // expected-error @below {{expected to contain 1 operation handles but it contains 2 handles}} + %h = split_handles %muli in [1] +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -292,6 +292,11 @@ DiagnosedSilenceableFailure mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply( transform::TransformResults &results, transform::TransformState &state) { + // Report zero for a null handle instead of querying the state with it. + if (!getHandle()) { + emitRemark() << 0; + return DiagnosedSilenceableFailure::success(); + } emitRemark() << state.getPayloadOps(getHandle()).size(); return DiagnosedSilenceableFailure::success(); }