diff --git a/mlir/docs/Tutorials/transform/Ch1.md b/mlir/docs/Tutorials/transform/Ch1.md --- a/mlir/docs/Tutorials/transform/Ch1.md +++ b/mlir/docs/Tutorials/transform/Ch1.md @@ -362,3 +362,16 @@ Note that the “add” elementwise operation is indicated as payload ancestor because it was used to produce the tile loop, and the loop therefore has its location. Finally, we would like to replace the call to the outlined function with a call to the microkernel. Unfortunately, the Transform dialect doesn’t have support for this transformation (and cannot have if the call is rewritten to a custom, out-of-tree operation). Therefore, we need to define new transform operations. The next chapters will describe how this can be done. + +## Tracking IR Modifications + +The transform dialect automatically tracks all IR changes that are made as part +of transform ops. (Implementations must use the provided rewriter to modify IR.) +If a payload op is erased, it is automatically removed from all handles that it +is currently associated with. If a payload op is replaced, the transform dialect +tries to find the replacement op and updates all handles accordingly. If a +multi-result op is replaced with values that are defined by multiple ops, or if +an op is replaced with an op of a different type, an error is produced. This is +because it is unclear whether the direct replacements actually represent the +computation of the original op. There are ways to customize this behavior. More +details can be found at the documentation of `transform::TrackingListener`. diff --git a/mlir/docs/Tutorials/transform/Ch2.md b/mlir/docs/Tutorials/transform/Ch2.md --- a/mlir/docs/Tutorials/transform/Ch2.md +++ b/mlir/docs/Tutorials/transform/Ch2.md @@ -189,8 +189,13 @@ } ``` -To finalize the definition of the transform operation, we need to implement the interface methods. The `TransformOpInterface` currently requires only one method – `apply` – that performs the actual transformation. It is a good practice to limit the body of the method to manipulation of the Transform dialect constructs and have the actual transformation implemented as a standalone function so it can be used from other places in the code. - +To finalize the definition of the transform operation, we need to implement the +interface methods. The `TransformOpInterface` currently requires only one method +– `apply` – that performs the actual transformation. It is a good practice to +limit the body of the method to manipulation of the Transform dialect constructs +and have the actual transformation implemented as a standalone function so it +can be used from other places in the code. Similar to rewrite patterns, all IR +must be modified with the provided rewriter. ```c++ // In MyExtension.cpp @@ -198,30 +203,32 @@ // Implementation of our transform dialect operation. // This operation returns a tri-state result that can be one of: // - success when the transformation succeeded; -// - definite failure when the transformation failed in such a way that following -// transformations are impossible or undesirable, typically it could have left payload -// IR in an invalid state; it is expected that a diagnostic is emitted immediately -// before returning the definite error; -// - silenceable failure when the transformation failed but following transformations -// are still applicable, typically this means a precondition for the transformation is -// not satisfied and the payload IR has not been modified. 
-// The silenceable failure additionally carries a Diagnostic that can be emitted to the -// user. -::mlir::DiagnosedSilenceableFailure ChangeCallTargetOp::apply( - // The list of payload IR entities that will be associated with the transform IR - // values defined by this transform operation. In this case, it can remain empty as - // there are no results. +// - definite failure when the transformation failed in such a way that +// following transformations are impossible or undesirable, typically it could +// have left payload IR in an invalid state; it is expected that a diagnostic +// is emitted immediately before returning the definite error; +// - silenceable failure when the transformation failed but following +// transformations are still applicable, typically this means a precondition +// for the transformation is not satisfied and the payload IR has not been +// modified. The silenceable failure additionally carries a Diagnostic that +// can be emitted to the user. +::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply( + // The rewriter that should be used when modifying IR. + ::mlir::transform::TransformRewriter &rewriter, + // The list of payload IR entities that will be associated with the + // transform IR values defined by this transform operation. In this case, it + // can remain empty as there are no results. ::mlir::transform::TransformResults &results, - // The transform application state. This object can be used to query the current - // associations between transform IR values and payload IR entities. It can also - // carry additional user-defined state. + // The transform application state. This object can be used to query the + // current associations between transform IR values and payload IR entities. + // It can also carry additional user-defined state. ::mlir::transform::TransformState &state) { - // First, we need to obtain the list of payload operations that are associated with + // First, we need to obtain the list of payload operations that are associated with // the operand handle. auto payload = state.getPayloadOps(getCall()); - - // Then, we iterate over the list of operands and call the actual IR-mutating + + // Then, we iterate over the list of operands and call the actual IR-mutating // function. We also check the preconditions here. for (Operation *payloadOp : payload) { auto call = dyn_cast<::mlir::func::CallOp>(payloadOp); @@ -231,7 +238,7 @@ diag.attachNote(payloadOp->getLoc()) << "offending payload"; return diag; } - + updateCallee(call, getNewTarget()); } diff --git a/mlir/docs/Tutorials/transform/Ch3.md b/mlir/docs/Tutorials/transform/Ch3.md --- a/mlir/docs/Tutorials/transform/Ch3.md +++ b/mlir/docs/Tutorials/transform/Ch3.md @@ -10,26 +10,26 @@ // Define the new operation. By convention, prefix its name with the name of the dialect extension, "my.". The full operation name will be further prefixed with "transform.". def ChangeCallTargetOp : Op]> { - // Provide a brief and a full description. It is recommended that the latter describes + // Provide a brief and a full description. It is recommended that the latter describes // the effects on the operands and how the operation processes various failure modes. 
let summary = "Changes the callee of a call operation to the specified one"; let description = [{ - For each `func.call` payload operation associated with the handle, changes its + For each `func.call` payload operation associated with the handle, changes its callee to be the symbol whose name is provided as an attribute to this operation. - Generates a silenceable failure if the operand is associated with payload operations + Generates a silenceable failure if the operand is associated with payload operations that are not `func.call`. Only reads the operand. }]; - // The arguments include the handle to the payload operations and the attribute that - // specifies the new callee. The handle must implement TransformHandleTypeInterface. - // We use a string attribute as the symbol may not exist in the transform IR so the - // verification may fail. + // The arguments include the handle to the payload operations and the attribute that + // specifies the new callee. The handle must implement TransformHandleTypeInterface. + // We use a string attribute as the symbol may not exist in the transform IR so the + // verification may fail. let arguments = (ins Transform_ConcreteOpType<"func.call">:$call, StrAttr:$new_target); @@ -43,6 +43,7 @@ // Declare the function implementing the interface for a single payload operation. let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::func::CallOp call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -54,7 +55,8 @@ ```c++ ::mlir::DiagnosedSilenceableFailure ChangeCallTargetOp::applyToOne( - ::mlir::func::CallOp call,, + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::func::CallOp call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state) { // Call the actual transformation function. @@ -150,14 +152,13 @@ As an exercise, let us modify the rewriting operation to consume the operand. This would be necessary, for example, if the transformation were to rewrite the `func.call` operation into a custom operation `my.mm4`. Since the operand handle is now consumed, the operation can return a new handle to the newly produced payload operation. Otherwise, the ODS definition of the transform operation remains unchanged. - ```tablegen // In MyExtension.td. // Define another transform operation. def CallToOp : Op]> { @@ -166,7 +167,7 @@ // The argument is the handle to the payload operations. let arguments = (ins CallOpInterfaceHandle:$call); - // The result is the handle to the payload operations produced during the + // The result is the handle to the payload operations produced during the // transformation. let results = (outs TransformHandleTypeInterface:$transformed); @@ -176,6 +177,7 @@ // Declare the function implementing the interface for a single payload operation. let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::CallOpInterface call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -189,23 +191,24 @@ // In MyExtension.cpp. ::mlir::DiagnosedSilenceableFailure CallToOp::applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::CallOpInterface call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state) { // Call the actual rewrite. 
Operation *rewritten = rewriteToOp(call); - // Report an error if the rewriter produced a null pointer. Note that it may have + // Report an error if the rewriter produced a null pointer. Note that it may have // irreversibly modified the payload IR, so we produce a definite failure. if (!rewritten) { return emitDefiniteError() << "failed to rewrite call to operation"; } - // On success, push the resulting operation into the result list. The list is expected - // to contain exactly one entity per result and per application. The handles will be + // On success, push the resulting operation into the result list. The list is expected + // to contain exactly one entity per result and per application. The handles will be // associated with lists of the respective values produced by each application. results.push_back(rewritten); - + // If everything is fine, return success. return DiagnosedSilenceableFailure::success(); } @@ -228,6 +231,7 @@ ```c++ ::mlir::DiagnosedSilenceableFailure SomeOtherOp::apply( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::transform::TransformResults &results, ::mlir::transform::TransformState &state) { // ... @@ -263,7 +267,7 @@ // The argument is the handle to the payload operations. let arguments = (ins CallOpInterfaceHandle:$call); - // The result is the handle to the payload operations produced during the + // The result is the handle to the payload operations produced during the // transformation. let results = (outs TransformHandleTypeInterface:$transformed); @@ -273,6 +277,7 @@ // Declare the function implementing the interface for a single payload operation. let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::CallOpInterface call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/examples/transform/Ch2/lib/MyExtension.cpp b/mlir/examples/transform/Ch2/lib/MyExtension.cpp --- a/mlir/examples/transform/Ch2/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch2/lib/MyExtension.cpp @@ -74,17 +74,17 @@ // This operation returns a tri-state result that can be one of: // - success when the transformation succeeded; // - definite failure when the transformation failed in such a way that -// following -// transformations are impossible or undesirable, typically it could have left -// payload IR in an invalid state; it is expected that a diagnostic is emitted -// immediately before returning the definite error; +// following transformations are impossible or undesirable, typically it could +// have left payload IR in an invalid state; it is expected that a diagnostic +// is emitted immediately before returning the definite error; // - silenceable failure when the transformation failed but following -// transformations -// are still applicable, typically this means a precondition for the -// transformation is not satisfied and the payload IR has not been modified. -// The silenceable failure additionally carries a Diagnostic that can be emitted -// to the user. +// transformations are still applicable, typically this means a precondition +// for the transformation is not satisfied and the payload IR has not been +// modified. The silenceable failure additionally carries a Diagnostic that +// can be emitted to the user. ::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply( + // The rewriter that should be used when modifying IR. 
+ ::mlir::transform::TransformRewriter &rewriter, // The list of payload IR entities that will be associated with the // transform IR values defined by this transform operation. In this case, it // can remain empty as there are no results. diff --git a/mlir/examples/transform/Ch3/include/MyExtension.td b/mlir/examples/transform/Ch3/include/MyExtension.td --- a/mlir/examples/transform/Ch3/include/MyExtension.td +++ b/mlir/examples/transform/Ch3/include/MyExtension.td @@ -60,6 +60,7 @@ // Declare the function implementing the interface for a single payload operation. let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::func::CallOp call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -89,6 +90,7 @@ // Declare the function implementing the interface for a single payload operation. let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::CallOpInterface call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/examples/transform/Ch3/lib/MyExtension.cpp b/mlir/examples/transform/Ch3/lib/MyExtension.cpp --- a/mlir/examples/transform/Ch3/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch3/lib/MyExtension.cpp @@ -113,6 +113,8 @@ // to the user. ::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::applyToOne( + // The rewriter that should be used when modifying IR. + ::mlir::transform::TransformRewriter &rewriter, // The single payload operation to which the transformation is applied. ::mlir::func::CallOp call, // The payload IR entities that will be appended to lists associated with @@ -146,27 +148,27 @@ // CallToOp //===---------------------------------------------------------------------===// -static mlir::Operation *replaceCallWithOp(mlir::CallOpInterface call) { +static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter, + mlir::CallOpInterface call) { // Construct an operation from an unregistered dialect. This is discouraged // and is only used here for brevity of the overall example. mlir::OperationState state(call.getLoc(), "my.mm4"); state.types.assign(call->result_type_begin(), call->result_type_end()); state.operands.assign(call->operand_begin(), call->operand_end()); - mlir::OpBuilder builder(call); - mlir::Operation *replacement = builder.create(state); - call->replaceAllUsesWith(replacement->getResults()); - call->erase(); + mlir::Operation *replacement = rewriter.create(state); + rewriter.replaceOp(call, replacement->getResults()); return replacement; } // See above for the signature description. mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne( - mlir::CallOpInterface call, mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call, + mlir::transform::ApplyToEachResultList &results, mlir::transform::TransformState &state) { // Dispatch to the actual transformation. - Operation *replacement = replaceCallWithOp(call); + Operation *replacement = replaceCallWithOp(rewriter, call); // Associate the payload operation produced by the rewrite with the result // handle of this transform operation. 
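The Ch3 example above routes the call rewrite through the rewriter so the tracking listener observes the replacement. A minimal sketch of the same pattern for a hypothetical downstream extension (the names `lowerCall` and `my.microkernel` are invented here, and, as in the example, creating ops from an unregistered dialect is discouraged outside tutorials); it assumes the includes of the surrounding extension file:

```c++
// Hypothetical helper mirroring replaceCallWithOp above: it accepts a
// RewriterBase so the transform op can hand in its TransformRewriter, and the
// original call is replaced through the rewriter rather than erased directly.
static mlir::Operation *lowerCall(mlir::RewriterBase &rewriter,
                                  mlir::CallOpInterface call) {
  // Assemble the replacement operation with the same results and operands.
  mlir::OperationState state(call.getLoc(), "my.microkernel");
  state.types.assign(call->result_type_begin(), call->result_type_end());
  state.operands.assign(call->operand_begin(), call->operand_end());
  mlir::Operation *replacement = rewriter.create(state);
  // replaceOp notifies the tracking listener, which remaps payload handles.
  rewriter.replaceOp(call, replacement->getResults());
  return replacement;
}
```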
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -152,6 +152,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::EmptyOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td --- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td @@ -126,6 +126,7 @@ }]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -189,6 +190,7 @@ }]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -80,7 +80,8 @@ def BufferizeToAllocationOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ This transform materializes an allocation for the targeted tensor value. It replaces all original uses of the target with the newly allocated buffer, @@ -133,7 +134,8 @@ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, TransformOpInterface, - TransformEachOpTrait]> { + TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Decomposes named complex operations, such as higher-dimensional (depthwise) convolutions, into combinations of lower-dimensional equivalents @@ -155,6 +157,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -167,7 +170,8 @@ def FuseOp : Op]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Tiles the operations pointed to by the target handle and fuses their producers greedily using the options provided as attributes. @@ -192,7 +196,8 @@ Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let summary = "Fuse a producer into a containing operation."; let description = [{ @@ -247,7 +252,8 @@ def GeneralizeOp : Op { + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Transforms a named structured operation into the generic form with the explicit attached region. 
@@ -270,6 +276,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -282,7 +289,8 @@ def InterchangeOp : Op { + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Interchanges the iterators of the operations pointed to by the target handle using the iterator interchange attribute. @@ -313,6 +321,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::GenericOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -326,7 +335,8 @@ FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, TransformEachOpTrait, - TransformOpInterface]> { + TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Rewrite a tensor.pack into tensor.pad + tensor.expand_shape + linalg.transpose. @@ -350,6 +360,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::PackOp target, ::mlir::transform::ApplyToEachResultList &transformResults, ::mlir::transform::TransformState &state); @@ -363,7 +374,8 @@ FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, TransformEachOpTrait, - TransformOpInterface]> { + TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Lower a tensor.unpack into empty + linalg.transpose + tensor.collapse_shape + tensor.extract_slice. @@ -389,6 +401,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::UnPackOp target, ::mlir::transform::ApplyToEachResultList &transformResults, ::mlir::transform::TransformState &state); @@ -461,7 +474,8 @@ def MultiTileSizesOp : Op, - TransformOpInterface, TransformEachOpTrait]> { + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Emits the IR computing the tile sizes `s1` and `s2` such that: @@ -531,6 +545,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, TransformState &state); @@ -543,7 +558,8 @@ def PackOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Pack a LinalgOp by applying a data tiling transformation on the op and packing the operands according to the `packed_sizes` specification. @@ -632,7 +648,8 @@ //===----------------------------------------------------------------------===// def PackGreedilyOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Target a Linalg op and rewrite it into packed LinalgOp form by trying to infer whether a known suboperation is embedded @@ -740,7 +757,8 @@ def PackTransposeOp : Op]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Apply a transposition to a single `tensor.pack` (resp. `tensor.unpack`) and update the `linalg.generic` op that consumes (resp. produces) the operation. 
@@ -803,7 +821,8 @@ def PadOp : Op { + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Pads the operations pointed to by the target handle using the options provides as operation attributes. @@ -837,6 +856,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -851,7 +871,8 @@ Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Helper transform used to hoist a tensor.pad target operation. This operation creates the packing loop nest required by the hoist_pad operation and makes @@ -931,6 +952,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::PadOp, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -944,7 +966,8 @@ def PromoteOp : Op { + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Promotes the specified operands of the target into a separate memory buffer. @@ -978,6 +1001,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -990,7 +1014,8 @@ def ReplaceOp : Op, - DeclareOpInterfaceMethods] # GraphRegionNoTerminator.traits> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait] # GraphRegionNoTerminator.traits> { let description = [{ Replace all `target` payload ops with the single op that is contained in this op's region. All targets must have zero arguments and must be isolated @@ -1018,7 +1043,8 @@ def ScalarizeOp : Op { + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that ops of a specific kind in the given function should be scalarized (i.e. their dynamic dimensions tiled by 1). @@ -1051,6 +1077,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1066,7 +1093,8 @@ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, TransformOpInterface, - TransformEachOpTrait]> { + TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Rewrite a supported tensor operation that is not in destination-passing style into a form that is in destination-passing style. @@ -1098,6 +1126,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1110,7 +1139,8 @@ def SplitOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that the given `target` op should be split into two complementary parts, which combined cover the entire iteration domain of the original op. 
@@ -1147,7 +1177,8 @@ def SplitReductionOp : Op { + TransformEachOpTrait, TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that the given `target` op should be transformed with the `splitReduction` transformation and split factor provided as attribute. @@ -1309,6 +1340,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1321,7 +1353,8 @@ def TileReductionUsingScfOp : Op { + TransformEachOpTrait, TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that the given `target` op should be transformed with the `tileReduction` transformation with the tile size provided as attribute. @@ -1413,6 +1446,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1426,7 +1460,8 @@ def TileReductionUsingForallOp : Op { + TransformEachOpTrait, TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Tile a PartialReductionOpInterface op to a tiled `scf.forall` doing partial reduction. @@ -1522,6 +1557,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1535,7 +1571,8 @@ def TileOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that the given `target` op should be tiled with the given sizes. This transform generates a loop nest with a smaller ("tiled") target @@ -1616,7 +1653,7 @@ Op, - TransformOpInterface]> { + TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> { let description = [{ Tile a TilingInterface op to a tiled `scf.forall`. @@ -1726,6 +1763,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::transform::TransformResults &transformResults, ::mlir::transform::TransformState &state); @@ -1740,7 +1778,8 @@ def TileToScfForOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that the given `target` op should be tiled with the given sizes. This transform generates a loop nest with a smaller ("tiled") target @@ -1807,7 +1846,8 @@ def VectorizeOp : Op { + TransformEachOpTrait, TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that the given `target` op all the ops it contains should be vectorized with the configuration specified by the attributes of this op. 
@@ -1862,6 +1902,7 @@ ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1870,7 +1911,7 @@ def MaskedVectorizeOp : Op, - TransformOpInterface]> { + TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> { let description = [{ Vectorize the target ops, which must be Linalg ops, with masked vectors of the specified size. @@ -1914,6 +1955,7 @@ let extraClassDeclaration = [{ // TODO: applyToOne. ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::transform::TransformResults &transformResults, ::mlir::transform::TransformState &state); @@ -1928,7 +1970,8 @@ def HoistRedundantVectorTransfersOp : Op { + TransformEachOpTrait, TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Hoist vector.transfer_read / vector.transfer_write pairs out of immediately enclosing scf::ForOp iteratively, if the following conditions are true: @@ -1958,7 +2001,8 @@ ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::func::FuncOp target, + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::func::FuncOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; @@ -1973,7 +2017,8 @@ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, TransformOpInterface, - TransformEachOpTrait]> { + TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Convert linalg.conv_2d_xxx into linalg.generic (for img2col packing) and linalg.matmul. @@ -2034,6 +2079,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -2048,7 +2094,8 @@ Op, TransformEachOpTrait, - TransformOpInterface]> { + TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Hoists supported tensor subset extract/insert operation pairs out of immediately enclosing loop iteratively, if the following conditions @@ -2082,6 +2129,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -2120,6 +2168,7 @@ ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -158,6 +158,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td 
b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -126,6 +126,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::scf::ForOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -168,6 +169,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::scf::ForOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -201,6 +203,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -228,6 +231,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -269,6 +273,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::scf::IfOp ifOp, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td --- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td +++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td @@ -138,6 +138,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h @@ -50,7 +50,8 @@ return success(); } - DiagnosedSilenceableFailure apply(TransformResults &results, + DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { Value operandHandle = cast(this->getOperation()).getOperandHandle(); auto payload = state.getPayloadOps(operandHandle); @@ -90,7 +91,8 @@ return success(); } - DiagnosedSilenceableFailure apply(TransformResults &results, + DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { Value operandHandle = cast(this->getOperation()).getOperandHandle(); ValueRange payload = state.getPayloadValues(operandHandle); diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -27,11 +27,16 @@ constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName = "transform.with_named_sequence"; - /// Names of the attribute attachable to an operation so it can be + /// Name of the attribute 
attachable to an operation so it can be /// identified as root by the default interpreter pass. constexpr const static ::llvm::StringLiteral kTargetTagAttrName = "transform.target_tag"; + /// Name of the attribute attachable to an operation, indicating that + /// TrackingListener failures should be silenced. + constexpr const static ::llvm::StringLiteral + kSilenceTrackingFailuresAttrName = "transform.silence_tracking_failures"; + /// Names of the attributes indicating whether an argument of an external /// transform dialect symbol is consumed or only read. constexpr const static ::llvm::StringLiteral diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -18,11 +18,11 @@ #include "mlir/Support/LogicalResult.h" namespace mlir { - namespace transform { class TransformOpInterface; class TransformResults; +class TransformRewriter; class TransformState; using Param = Attribute; @@ -856,8 +856,7 @@ public TransformState::Extension { public: /// Create a new TrackingListener for usage in the specified transform op. - explicit TrackingListener(TransformState &state, TransformOpInterface op) - : TransformState::Extension(state), transformOp(op) {} + TrackingListener(TransformState &state, TransformOpInterface op); protected: /// Return a replacement payload op for the given op, which is going to be @@ -937,6 +936,9 @@ /// The transform op in which this TrackingListener is used. TransformOpInterface transformOp; + + /// The handles that are consumed by the transform op. + DenseSet consumedHandles; }; /// A specialized listener that keeps track of cases in which no replacement @@ -968,6 +970,28 @@ int64_t errorCounter = 0; }; +/// This is a special rewriter to be used in transform op implementations, +/// providing additional helper functions to update the transform state, etc. +// TODO: Helper functions will be added in a subsequent change. +class TransformRewriter : public RewriterBase { +protected: + friend class TransformState; + + /// Create a new TransformRewriter. + explicit TransformRewriter(MLIRContext *ctx, + ErrorCheckingTrackingListener *listener); + +public: + /// Return "true" if the tracking listener had failures. + bool hasTrackingFailures() const; + + /// Silence all tracking failures that have been encountered so far. + void silenceTrackingFailure(); + +private: + ErrorCheckingTrackingListener *const listener; +}; + /// This trait is supposed to be attached to Transform dialect operations that /// can be standalone top-level transforms. Such operations typically contain /// other Transform dialect operations that can be executed following some @@ -1064,7 +1088,8 @@ /// 5. If any `applyToOne` return silenceableFailure, the transformation is /// considered silenceable. /// 6. Otherwise the transformation is considered successful. - DiagnosedSilenceableFailure apply(TransformResults &transformResults, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state); /// Checks that the op matches the expectations of this trait. @@ -1213,6 +1238,15 @@ } }; +/// `TrackingListener` failures are reported only for ops that have this trait. 
+/// The purpose of this trait is to give users more time to update their custom +/// transform ops to use the provided `TransformRewriter` for all IR +/// modifications. This trait will eventually be removed, and failures will be +/// reported for all transform ops. +template +class ReportTrackingListenerFailuresOpTrait + : public OpTrait::TraitBase {}; + /// A single result of applying a transform op with `ApplyEachOpTrait` to a /// single payload operation. using ApplyToEachResult = MappedValue; @@ -1307,14 +1341,14 @@ /// `targets` contains operations of the same class and a silenceable failure /// is reported if it does not. template -DiagnosedSilenceableFailure -applyTransformToEach(TransformOpTy transformOp, Range &&targets, - SmallVectorImpl &results, - TransformState &state) { +DiagnosedSilenceableFailure applyTransformToEach( + TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets, + SmallVectorImpl &results, TransformState &state) { using OpTy = typename llvm::function_traits< - decltype(&TransformOpTy::applyToOne)>::template arg_t<0>; + decltype(&TransformOpTy::applyToOne)>::template arg_t<1>; static_assert(std::is_convertible::value, "expected transform function to take an operation"); + OpBuilder::InsertionGuard g(rewriter); SmallVector silenceableStack; unsigned expectedNumResults = transformOp->getNumResults(); @@ -1331,8 +1365,9 @@ ApplyToEachResultList partialResults; partialResults.reserve(expectedNumResults); Location specificOpLoc = specificOp->getLoc(); + rewriter.setInsertionPoint(specificOp); DiagnosedSilenceableFailure res = - transformOp.applyToOne(specificOp, partialResults, state); + transformOp.applyToOne(rewriter, specificOp, partialResults, state); if (res.isDefiniteFailure()) return DiagnosedSilenceableFailure::definiteFailure(); @@ -1367,7 +1402,8 @@ template mlir::DiagnosedSilenceableFailure mlir::transform::TransformEachOpTrait::apply( - TransformResults &transformResults, TransformState &state) { + TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { Value handle = this->getOperation()->getOperand(0); auto targets = state.getPayloadOps(handle); @@ -1403,7 +1439,7 @@ // corresponding results entry. SmallVector results; DiagnosedSilenceableFailure result = detail::applyTransformToEach( - cast(this->getOperation()), targets, results, state); + cast(this->getOperation()), rewriter, targets, results, state); // Step 3. Propagate the definite failure if any and bail out. if (result.isDefiniteFailure()) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -41,10 +41,15 @@ transformation represented by the current op is targeted. Returns a special status object indicating whether the transformation succeeded or failed, and, if it failed, whether the failure is recoverable or not. + + IR must be created, modified and deleted with the provided rewriter. + implementations are responsible for setting the insertion point of the + rewriter to the desired location. 
}], /*returnType=*/"::mlir::DiagnosedSilenceableFailure", /*name=*/"apply", /*arguments=*/(ins + "::mlir::transform::TransformRewriter &":$rewriter, "::mlir::transform::TransformResults &":$transformResults, "::mlir::transform::TransformState &":$state )>, @@ -194,6 +199,10 @@ let cppNamespace = "::mlir::transform"; } +def ReportTrackingListenerFailuresOpTrait : NativeOpTrait<"ReportTrackingListenerFailuresOpTrait"> { + let cppNamespace = "::mlir::transform"; +} + def FindPayloadReplacementOpInterface : OpInterface<"FindPayloadReplacementOpInterface"> { let description = [{ diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -129,7 +129,8 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns", [TransformOpInterface, TransformEachOpTrait, - DeclareOpInterfaceMethods] + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait] # GraphRegionNoTerminator.traits> { let summary = "Greedily applies patterns to the body of the targeted op"; let description = [{ @@ -149,8 +150,8 @@ considered "payload op replacements". Furthermore, only if the replacement values are defined by the same op and that op has the same type as the original op, the mapping is updated. Otherwise, this transform fails - silently unless `fail_on_payload_replacement_not_found` is set to "false". - More details can be found at the documentation site of `TrackingListener`. + silently. More details can be found at the documentation site of + `TrackingListener`. This transform also fails silently if the pattern application did not converge within the default number of iterations/rewrites of the greedy @@ -158,8 +159,7 @@ }]; let arguments = (ins - TransformHandleTypeInterface:$target, - DefaultValuedAttr:$fail_on_payload_replacement_not_found); + TransformHandleTypeInterface:$target); let results = (outs); let regions = (region MaxSizedRegion<1>:$region); @@ -171,12 +171,12 @@ OpBuilder<(ins "Value":$target, CArg<"function_ref", "nullptr">: - $bodyBuilder, - CArg<"bool", "true">:$failOnPayloadReplacementNotFound)>, + $bodyBuilder)>, ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -222,6 +222,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -238,6 +239,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -63,7 +63,8 @@ } // namespace DiagnosedSilenceableFailure -SimplifyBoundedAffineOpsOp::apply(TransformResults &results, +SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { 
// Get constraints for bounded values. SmallVector lbs; @@ -127,6 +128,8 @@ SimplifyAffineMinMaxOp>(getContext(), cstr); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); GreedyRewriteConfig config; + config.listener = + static_cast(rewriter.getListener()); config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; // Apply the simplification pattern to a fixpoint. if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) { diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -26,7 +26,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::OneShotBufferizeOp::apply(TransformResults &transformResults, +transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { OneShotBufferizationOptions options; options.allowReturnAllocs = getAllowReturnAllocs(); @@ -71,10 +72,9 @@ modifiesPayload(effects); } -DiagnosedSilenceableFailure -transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults, - TransformState &state) { - IRRewriter rewriter(getContext()); +DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( + transform::TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { OneShotBufferizationOptions options; options.allowReturnAllocs = true; @@ -95,11 +95,9 @@ // EmptyTensorToAllocTensorOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target, - ApplyToEachResultList &results, - transform::TransformState &state) { - IRRewriter rewriter(target->getContext()); +DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne( + transform::TransformRewriter &rewriter, tensor::EmptyOp target, + ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); auto alloc = rewriter.replaceOpWithNewOp( target, target.getType(), target.getDynamicSizes()); diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -398,7 +398,7 @@ /// Alter kernel configuration of the given kernel. 
static DiagnosedSilenceableFailure -alterGpuLaunch(IRRewriter &rewriter, LaunchOp gpuLaunch, +alterGpuLaunch(RewriterBase &rewriter, LaunchOp gpuLaunch, TransformOpInterface transformOp, std::optional gridDimX = std::nullopt, std::optional gridDimY = std::nullopt, @@ -661,12 +661,10 @@ return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure -transform::MapForallToBlocks::applyToOne(Operation *target, - ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, transform::TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); - IRRewriter rewriter(getContext()); auto transformOp = cast(getOperation()); if (!getGenerateGpuLaunch() && !gpuLaunch) { @@ -856,7 +854,8 @@ } DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne( - Operation *target, ApplyToEachResultList &results, TransformState &state) { + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); auto transformOp = cast(getOperation()); @@ -877,7 +876,6 @@ // Set the GPU launch configuration for the block dims early, this is not // subject to IR inspection. - IRRewriter rewriter(getContext()); diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt, std::nullopt, std::nullopt, blockDims[0], blockDims[1], blockDims[2]); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -169,13 +169,11 @@ // BufferizeToAllocationOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -transform::BufferizeToAllocationOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { Attribute memorySpace = getMemorySpace().has_value() ? 
getMemorySpace().value() : Attribute(); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); auto transformed = llvm::to_vector( llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) { return linalg::bufferizeToAllocation(rewriter, v, memorySpace); @@ -196,7 +194,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::DecomposeOp::applyToOne(LinalgOp target, +transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { #define DOWNSCALE(trans) \ @@ -286,7 +285,8 @@ } DiagnosedSilenceableFailure -transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, +transform::FuseOp::apply(transform::TransformRewriter &rewriter, + mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { SmallVector tileSizes = extractFromI64ArrayAttr(getTileSizes()); SmallVector tileInterchange = @@ -297,8 +297,6 @@ tilingOptions = tilingOptions.setTileSizes(tileSizes); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); LogicalResult result = applyTilingToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), transformResults, @@ -721,7 +719,8 @@ } DiagnosedSilenceableFailure -transform::FuseIntoContainingOp::apply(transform::TransformResults &results, +transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector fusedOps; auto producerOps = state.getPayloadOps(getProducerOp()); @@ -764,8 +763,6 @@ return failure(); }; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); while (!remainingProducers.empty()) { auto nextProducer = getNextProducer(); if (failed(nextProducer)) { @@ -842,7 +839,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GeneralizeOp::applyToOne(LinalgOp target, +transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Exit early if no transformation is needed. 
@@ -850,8 +848,6 @@ results.push_back(target); return DiagnosedSilenceableFailure::success(); } - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr generic = generalizeNamedOp(rewriter, target); if (succeeded(generic)) { @@ -866,7 +862,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::InterchangeOp::applyToOne(GenericOp target, +transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter, + GenericOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { ArrayRef interchangeVector = getIteratorInterchange(); @@ -875,8 +872,6 @@ results.push_back(target); return DiagnosedSilenceableFailure::success(); } - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr res = interchangeGenericOp(rewriter, target, SmallVector(interchangeVector.begin(), @@ -904,10 +899,9 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne( - tensor::PackOp target, transform::ApplyToEachResultList &transformResults, + transform::TransformRewriter &rewriter, tensor::PackOp target, + transform::ApplyToEachResultList &transformResults, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr res = lowerPack(rewriter, target); if (failed(res)) { @@ -925,10 +919,9 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne( - tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults, + transform::TransformRewriter &rewriter, tensor::UnPackOp target, + transform::ApplyToEachResultList &transformResults, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr res = lowerUnPack(rewriter, target); if (failed(res)) { @@ -964,7 +957,8 @@ } DiagnosedSilenceableFailure -transform::MatchOp::apply(transform::TransformResults &results, +transform::MatchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { llvm::StringSet<> strs; if (getOps().has_value()) @@ -1053,8 +1047,8 @@ } DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, - TransformState &state) { + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, TransformState &state) { if (isa(getLowSize().getType())) { if (target.hasDynamicShape()) { auto diag = emitSilenceableError() @@ -1155,7 +1149,8 @@ } DiagnosedSilenceableFailure -transform::PackOp::apply(transform::TransformResults &transformResults, +transform::PackOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); // If nothing to pack, propagate success. 
@@ -1184,8 +1179,6 @@ DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations( state, *this, packedSizes, getMixedPackedSizes()); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(linalgOp); FailureOr maybeResult = pack(rewriter, linalgOp, packedSizes); if (failed(maybeResult)) @@ -1364,11 +1357,10 @@ } DiagnosedSilenceableFailure -PackGreedilyOp::apply(transform::TransformResults &transformResults, +PackGreedilyOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); for (Operation *op : state.getPayloadOps(getTarget())) { auto linalgOp = dyn_cast(op); if (!linalgOp) @@ -1464,7 +1456,8 @@ } DiagnosedSilenceableFailure -transform::PackTransposeOp::apply(transform::TransformResults &transformResults, +transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp()); auto linalgOps = state.getPayloadOps(getTargetLinalgOp()); @@ -1542,8 +1535,6 @@ assert(packOp && linalgOp && "unexpected null op"); // Step 3. Actually transpose the ops. - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr res = packTranspose( rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm()); // Preconditions have been checked, it is an error to fail here. @@ -1568,7 +1559,8 @@ //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PadOp::applyToOne(LinalgOp target, +transform::PadOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Convert the integer packing flags to booleans. 
@@ -1616,8 +1608,6 @@ transposePaddings.push_back( extractFromI64ArrayAttr(cast(transposeVector))); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); LinalgOp paddedOp; SmallVector paddingDimensions = extractFromI64ArrayAttr(getPaddingDimensions()); @@ -1684,6 +1674,7 @@ //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); @@ -1700,8 +1691,6 @@ if (!padOp || !loopOp) return emitDefiniteFailure() << "requires exactly 2 non-null handles"; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr result = linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp, getTranspose()); @@ -1740,13 +1729,12 @@ } DiagnosedSilenceableFailure -transform::HoistPadOp::applyToOne(tensor::PadOp target, +transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter, + tensor::PadOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { tensor::PadOp hoistedPadOp; SmallVector transposeOps; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr result = hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(), hoistedPadOp, transposeOps); @@ -1779,7 +1767,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PromoteOp::applyToOne(LinalgOp target, +transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { LinalgPromotionOptions promotionOptions; @@ -1829,8 +1818,6 @@ if (failed(promoteSubviewsPrecondition(target, promotionOptions))) return emitDefaultDefiniteFailure(target); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) @@ -1844,7 +1831,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ReplaceOp::apply(TransformResults &transformResults, +transform::ReplaceOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { auto payload = state.getPayloadOps(getTarget()); @@ -1859,8 +1847,6 @@ } // Clone and replace. 
- TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); Operation *pattern = &getBodyRegion().front().front(); SmallVector replacements; for (Operation *target : payload) { @@ -1904,7 +1890,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ScalarizeOp::applyToOne(LinalgOp target, +transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::SCFTilingOptions tilingOptions; @@ -1916,8 +1903,6 @@ AffineMap map = target.getShapesToLoopsMap(); if (!map) return tileSizes; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); SmallVector shapeSizes = affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map, allShapeSizes); @@ -1931,8 +1916,6 @@ return tileSizes; }); SmallVector emptyTileSizes; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(target.getOperation()), tilingOptions); @@ -1956,11 +1939,10 @@ DiagnosedSilenceableFailure transform::RewriteInDestinationPassingStyleOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { SmallVector res; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr maybeResult = TypeSwitch>(target) @@ -1978,13 +1960,12 @@ // SplitOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, - TransformState &state) { +DiagnosedSilenceableFailure +SplitOp::apply(transform::TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { // Collect the dynamic split points if provided. 
SmallVector payload = llvm::to_vector(state.getPayloadOps(getTarget())); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); SmallVector splitPoints; splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { @@ -2184,15 +2165,14 @@ } DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return linalg::SplitReductionOptions{int64_t(getSplitFactor()), unsigned(getInsertSplitDimension()), bool(getInnerParallel())}; }; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr splitResult = (getUseScalingAlgorithm()) @@ -2230,10 +2210,9 @@ } DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr result = scf::tileReductionUsingScf( rewriter, cast(target.getOperation()), @@ -2274,10 +2253,9 @@ } DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); SmallVector numThreads = getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); @@ -2363,7 +2341,8 @@ } DiagnosedSilenceableFailure -transform::TileOp::apply(TransformResults &transformResults, +transform::TileOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { ArrayRef tileSizes = getStaticSizes(); @@ -2478,8 +2457,6 @@ } tilingOptions.setInterchange(getInterchange()); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr maybeTilingResult = tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions); if (failed(maybeTilingResult)) @@ -2720,45 +2697,44 @@ } DiagnosedSilenceableFailure -transform::TileToForallOp::apply(transform::TransformResults &transformResults, +transform::TileToForallOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); - auto transformOp = cast(getOperation()); - - // Result payload ops. - SmallVector tileOps; - SmallVector tiledOps; - - // Unpack handles. - SmallVector mixedNumThreads; - DiagnosedSilenceableFailure status = - getPackedNumThreads() - ? unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedNumThreads, getPackedNumThreads()) - : unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedNumThreads, getMixedNumThreads()); - if (!status.succeeded()) - return status; - SmallVector mixedTileSizes; - status = getPackedTileSizes() - ? 
unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedTileSizes, getPackedTileSizes()) - : unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedTileSizes, getMixedTileSizes()); - if (!status.succeeded()) - return status; - - for (Operation *target : state.getPayloadOps(getTarget())) { - linalg::ForallTilingResult tilingResult; - DiagnosedSilenceableFailure diag = tileToForallOpImpl( - rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, - getMapping(), tilingResult); - if (!diag.succeeded()) + auto transformOp = cast(getOperation()); + + // Result payload ops. + SmallVector tileOps; + SmallVector tiledOps; + + // Unpack handles. + SmallVector mixedNumThreads; + DiagnosedSilenceableFailure status = + getPackedNumThreads() + ? unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedNumThreads, getPackedNumThreads()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedNumThreads, getMixedNumThreads()); + if (!status.succeeded()) + return status; + SmallVector mixedTileSizes; + status = getPackedTileSizes() + ? unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getPackedTileSizes()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getMixedTileSizes()); + if (!status.succeeded()) + return status; + + for (Operation *target : state.getPayloadOps(getTarget())) { + linalg::ForallTilingResult tilingResult; + DiagnosedSilenceableFailure diag = tileToForallOpImpl( + rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, + getMapping(), tilingResult); + if (!diag.succeeded()) return diag; tileOps.push_back(tilingResult.tileOp); tiledOps.push_back(tilingResult.tiledOp); - } + } transformResults.set(cast(getForallOp()), tileOps); transformResults.set(cast(getTiledOp()), tiledOps); @@ -2833,7 +2809,8 @@ } DiagnosedSilenceableFailure -transform::TileToScfForOp::apply(TransformResults &transformResults, +transform::TileToScfForOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { ArrayRef tileSizes = getStaticSizes(); @@ -2902,8 +2879,6 @@ } tilingOptions.setInterchange(getInterchange()); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr tilingResult = tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions); if (failed(tilingResult)) @@ -3047,7 +3022,8 @@ } // namespace DiagnosedSilenceableFailure -transform::VectorizeOp::applyToOne(Operation *target, +transform::VectorizeOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { if (!target->hasTrait()) { @@ -3094,10 +3070,9 @@ // MaskedVectorizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply( + transform::TransformRewriter &rewriter, mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); auto targets = state.getPayloadOps(getTarget()); if (std::empty(targets)) return DiagnosedSilenceableFailure::success(); @@ -3175,7 +3150,8 @@ DiagnosedSilenceableFailure transform::HoistRedundantVectorTransfersOp::applyToOne( - func::FuncOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, func::FuncOp target, + 
transform::ApplyToEachResultList &results, transform::TransformState &state) { // WARNING: This hoisting does not model parallelism and is generally // incorrect when used on distributed loops with memref semantics! @@ -3190,10 +3166,9 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne( - linalg::LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, linalg::LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); auto maybeTransformed = TypeSwitch>>( @@ -3225,10 +3200,9 @@ DiagnosedSilenceableFailure transform::HoistRedundantTensorSubsetsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); auto forOp = dyn_cast(target); if (forOp) { linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp); @@ -3291,11 +3265,10 @@ } DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne( - Operation *targetOp, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *targetOp, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(targetOp->getContext(), &listener); rewriter.setInsertionPoint(targetOp); if (auto target = dyn_cast(targetOp)) return doit(rewriter, target, results, state); diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -60,10 +60,10 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; - IRRewriter rewriter(getContext()); for (Operation *op : state.getPayloadOps(getTarget())) { bool canApplyMultiBuffer = true; auto target = cast(op); @@ -105,7 +105,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // Gather IVs. SmallVector ivs; @@ -123,7 +124,6 @@ } // Rewrite IR. 
- IRRewriter rewriter(target->getContext()); FailureOr replacement = failure(); if (auto allocaOp = dyn_cast(target)) { replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -35,7 +35,8 @@ // GetParentForOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetParentForOp::apply(transform::TransformResults &results, +transform::GetParentForOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SetVector parents; for (Operation *target : state.getPayloadOps(getTarget())) { @@ -92,7 +93,8 @@ } DiagnosedSilenceableFailure -transform::LoopOutlineOp::apply(transform::TransformResults &results, +transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector functions; SmallVector calls; @@ -100,7 +102,6 @@ for (Operation *target : state.getPayloadOps(getTarget())) { Location location = target->getLoc(); Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); - IRRewriter rewriter(getContext()); scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); if (!exec) { DiagnosedSilenceableFailure diag = emitSilenceableError() @@ -135,11 +136,11 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::LoopPeelOp::applyToOne(scf::ForOp target, +transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter, + scf::ForOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::ForOp result; - IRRewriter rewriter(target->getContext()); // This helper returns failure when peeling does not occur (i.e. when the IR // is not modified). 
This is not a failure for the op as the postcondition: // "the loop trip count is divisible by the step" @@ -192,7 +193,8 @@ } DiagnosedSilenceableFailure -transform::LoopPipelineOp::applyToOne(scf::ForOp target, +transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter, + scf::ForOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::PipeliningOption options; @@ -203,7 +205,6 @@ getReadLatency()); }; scf::ForLoopPipeliningPattern pattern(options, target->getContext()); - IRRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr patternResult = scf::pipelineForLoop(rewriter, target, options); @@ -219,7 +220,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::LoopUnrollOp::applyToOne(Operation *op, +transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *op, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); @@ -241,7 +243,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::LoopCoalesceOp::applyToOne(Operation *op, +transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *op, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); @@ -276,12 +279,10 @@ } DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne( - scf::IfOp ifOp, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, scf::IfOp ifOp, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(ifOp->getContext(), &listener); rewriter.setInsertionPoint(ifOp); - Region ®ion = getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion(); if (!llvm::hasSingleElement(region)) { diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -123,7 +123,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // Gather IVs. SmallVector ivs; @@ -141,7 +142,6 @@ } // Rewrite IR. 
-  IRRewriter rewriter(target->getContext());
   FailureOr<Value> replacement = failure();
   if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
     replacement = tensor::buildIndependentOp(rewriter, padOp, ivs);
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -157,6 +157,13 @@
     }
     return success();
   }
+  if (attribute.getName().getValue() == kSilenceTrackingFailuresAttrName) {
+    if (!llvm::isa<UnitAttr>(attribute.getValue())) {
+      return op->emitError()
+             << attribute.getName() << " must be a unit attribute";
+    }
+    return success();
+  }
   return emitError(op->getLoc())
          << "unknown attribute: " << attribute.getName();
 }
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -8,6 +8,7 @@
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Operation.h"
@@ -896,12 +897,41 @@
     return diag;
   }
 
+  // Prepare rewriter and listener.
+  transform::ErrorCheckingTrackingListener trackingListener(*this, transform);
+  transform::TransformRewriter rewriter(transform->getContext(),
+                                        &trackingListener);
+
   // Compute the result but do not short-circuit the silenceable failure case as
   // we still want the handles to propagate properly so the "suppress" mode can
   // proceed on a best effort basis.
   transform::TransformResults results(transform->getNumResults());
-  DiagnosedSilenceableFailure result(transform.apply(results, *this));
+  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
   compactOpHandles();
+
+  // Error handling: fail if transform or listener failed.
+  DiagnosedSilenceableFailure trackingFailure =
+      trackingListener.checkAndResetError();
+  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
+      transform->hasAttr(
+          transform::TransformDialect::kSilenceTrackingFailuresAttrName)) {
+    // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
+    // do not report failures if the above mentioned attribute is set.
+    if (trackingFailure.isSilenceableFailure())
+      (void)trackingFailure.silence();
+    trackingFailure = DiagnosedSilenceableFailure::success();
+  }
+  if (!trackingFailure.succeeded()) {
+    if (result.succeeded()) {
+      result = std::move(trackingFailure);
+    } else {
+      // Transform op errors have precedence, report those first.
+      if (result.isSilenceableFailure())
+        result.attachNote() << "tracking listener also failed: "
+                            << trackingFailure.getMessage();
+      (void)trackingFailure.silence();
+    }
+  }
 
   if (result.isDefiniteFailure())
     return result;
@@ -1161,6 +1191,16 @@
 // TrackingListener
 //===----------------------------------------------------------------------===//
 
+transform::TrackingListener::TrackingListener(TransformState &state,
+                                              TransformOpInterface op)
+    : TransformState::Extension(state), transformOp(op) {
+  if (op) {
+    for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
+      consumedHandles.insert(opOperand->get());
+    }
+  }
+}
+
 Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
   Operation *defOp = nullptr;
   for (Value v : values) {
@@ -1267,15 +1307,34 @@
     // Op is not tracked.
     return;
   }
+
+  // Helper function to check if the current transform op consumes any handle
+  // that is mapped to `op`.
+  //
+  // Note: If a handle was consumed, there shouldn't be any alive users, so it
+  // is not really necessary to check for consumed handles. However, in case
+  // there are indeed alive handles that were consumed (which is undefined
+  // behavior) and a replacement op could not be found, we want to fail with a
+  // nicer error message: "op uses a handle invalidated..." instead of "could
+  // not find replacement op". This nicer error is produced later.
+  auto handleWasConsumed = [&] {
+    return llvm::any_of(opHandles,
+                        [&](Value h) { return consumedHandles.contains(h); });
+  };
+
+  // Helper function to check if the handle is alive.
   auto hasAliveUser = [&]() {
-    for (Value v : opHandles)
+    for (Value v : opHandles) {
       for (Operation *user : v.getUsers())
-        if (!happensBefore(user, transformOp))
+        if (user != transformOp && !happensBefore(user, transformOp))
           return true;
+    }
     return false;
   };
-  if (!hasAliveUser()) {
-    // The op is tracked but the corresponding handles are dead.
+
+  if (!hasAliveUser() || handleWasConsumed()) {
+    // The op is tracked but the corresponding handles are dead or were
+    // consumed. Drop the op from the mapping.
     (void)replacePayloadOp(op, nullptr);
     return;
   }
@@ -1326,6 +1385,28 @@
   ++errorCounter;
 }
 
+//===----------------------------------------------------------------------===//
+// TransformRewriter
+//===----------------------------------------------------------------------===//
+
+transform::TransformRewriter::TransformRewriter(
+    MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
+    : RewriterBase(ctx), listener(listener) {
+  setListener(listener);
+}
+
+bool transform::TransformRewriter::hasTrackingFailures() const {
+  return listener->failed();
+}
+
+/// Silence all tracking failures that have been encountered so far.
+void transform::TransformRewriter::silenceTrackingFailure() {
+  if (hasTrackingFailures()) {
+    DiagnosedSilenceableFailure status = listener->checkAndResetError();
+    (void)status.silence();
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Utilities for TransformEachOpTrait.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1,4 +1,4 @@
-//===- TransformDialect.cpp - Transform dialect operations ----------------===//
+//===- TransformOps.cpp - Transform dialect operations --------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformOps.h" + #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" @@ -115,7 +116,8 @@ } DiagnosedSilenceableFailure -transform::AlternativesOp::apply(transform::TransformResults &results, +transform::AlternativesOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector originals; if (Value scopeHandle = getScope()) @@ -222,7 +224,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::AnnotateOp::apply(transform::TransformResults &results, +transform::AnnotateOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector targets = llvm::to_vector(state.getPayloadOps(getTarget())); @@ -258,10 +261,9 @@ // ApplyPatternsOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -transform::ApplyPatternsOp::applyToOne(Operation *target, - ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, transform::TransformState &state) { // Make sure that this transform is not applied to itself. Modifying the // transform IR while it is being interpreted is generally dangerous. Even // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver @@ -282,9 +284,9 @@ } // Configure the GreedyPatternRewriteDriver. - ErrorCheckingTrackingListener listener(state, *this); GreedyRewriteConfig config; - config.listener = &listener; + config.listener = + static_cast(rewriter.getListener()); LogicalResult result = failure(); if (target->hasTrait()) { @@ -312,14 +314,6 @@ << "greedy pattern application failed"; } - // Check listener state for tracking errors. - if (listener.failed()) { - DiagnosedSilenceableFailure status = listener.checkAndResetError(); - if (getFailOnPayloadReplacementNotFound()) - return status; - (void)status.silence(); - } - return DiagnosedSilenceableFailure::success(); } @@ -346,12 +340,8 @@ void transform::ApplyPatternsOp::build( OpBuilder &builder, OperationState &result, Value target, - function_ref bodyBuilder, - bool failOnPayloadReplacementNotFound) { + function_ref bodyBuilder) { result.addOperands(target); - result.getOrAddProperties() - .fail_on_payload_replacement_not_found = - builder.getBoolAttr(failOnPayloadReplacementNotFound); OpBuilder::InsertionGuard g(builder); Region *region = result.addRegion(); @@ -377,10 +367,9 @@ // ApplyRegisteredPassOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -transform::ApplyRegisteredPassOp::applyToOne(Operation *target, - ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, transform::TransformState &state) { // Make sure that this transform is not applied to itself. Modifying the // transform IR while it is being interpreted is generally dangerous. 
Even // more so when applying passes because they may perform a wide range of IR @@ -420,7 +409,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::CastOp::applyToOne(Operation *target, ApplyToEachResultList &results, +transform::CastOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *target, ApplyToEachResultList &results, transform::TransformState &state) { results.push_back(target); return DiagnosedSilenceableFailure::success(); @@ -482,7 +472,8 @@ } DiagnosedSilenceableFailure -transform::ForeachMatchOp::apply(transform::TransformResults &results, +transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector> matchActionPairs; @@ -780,7 +771,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ForeachOp::apply(transform::TransformResults &results, +transform::ForeachOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector> resultOps(getNumResults(), {}); @@ -869,6 +861,7 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SetVector parents; for (Operation *target : state.getPayloadOps(getTarget())) { @@ -892,7 +885,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetConsumersOfResult::apply(transform::TransformResults &results, +transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t resultNumber = getResultNumber(); auto payloadOps = state.getPayloadOps(getTarget()); @@ -917,7 +911,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetDefiningOp::apply(transform::TransformResults &results, +transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector definingOps; for (Value v : state.getPayloadValues(getTarget())) { @@ -938,7 +933,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetProducerOfOperand::apply(transform::TransformResults &results, +transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t operandNumber = getOperandNumber(); SmallVector producers; @@ -966,7 +962,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetResultOp::apply(transform::TransformResults &results, +transform::GetResultOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t resultNumber = getResultNumber(); SmallVector opResults; @@ -1017,7 +1014,8 @@ } DiagnosedSilenceableFailure -transform::IncludeOp::apply(transform::TransformResults &results, +transform::IncludeOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, 
transform::TransformState &state) { auto callee = SymbolTable::lookupNearestSymbolFrom( getOperation(), getTarget()); @@ -1155,7 +1153,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::MatchParamCmpIOp::apply(transform::TransformResults &results, +transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto signedAPIntAsString = [&](APInt value) { std::string str; @@ -1241,7 +1240,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ParamConstantOp::apply(transform::TransformResults &results, +transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.setParams(cast(getParam()), {getValue()}); return DiagnosedSilenceableFailure::success(); @@ -1252,7 +1252,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::MergeHandlesOp::apply(transform::TransformResults &results, +transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector operations; for (Value operand : getHandles()) @@ -1295,7 +1296,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::NamedSequenceOp::apply(transform::TransformResults &results, +transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { // Nothing to do here. return DiagnosedSilenceableFailure::success(); @@ -1461,7 +1463,8 @@ } DiagnosedSilenceableFailure -transform::SplitHandleOp::apply(transform::TransformResults &results, +transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle())); auto produceNumOpsError = [&]() { @@ -1522,7 +1525,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ReplicateOp::apply(transform::TransformResults &results, +transform::ReplicateOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern())); for (const auto &en : llvm::enumerate(getHandles())) { @@ -1562,7 +1566,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::SequenceOp::apply(transform::TransformResults &results, +transform::SequenceOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { // Map the entry block argument to the list of operations. 
auto scope = state.make_region_scope(*getBodyBlock()->getParent()); @@ -1852,7 +1857,8 @@ } DiagnosedSilenceableFailure -transform::PrintOp::apply(transform::TransformResults &results, +transform::PrintOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { llvm::outs() << "[[[ IR printer: "; if (getName().has_value()) diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp --- a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp @@ -138,7 +138,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PDLMatchOp::apply(transform::TransformResults &results, +transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); assert(extension && @@ -167,7 +168,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::WithPDLPatternsOp::apply(transform::TransformResults &results, +transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { TransformOpInterface transformOp = nullptr; for (Operation &nested : getBody().front()) { diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir --- a/mlir/test/Dialect/Transform/test-pattern-application.mlir +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir @@ -88,7 +88,7 @@ %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.apply_patterns to %0 { transform.apply_patterns.transform.test_patterns - } {fail_on_payload_replacement_not_found = false} : !transform.any_op + } {transform.silence_tracking_failures} : !transform.any_op transform.annotate %1 "annotated" : !transform.any_op } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -45,7 +45,8 @@ return llvm::StringLiteral("transform.test_transform_op"); } - DiagnosedSilenceableFailure apply(transform::TransformResults &results, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { InFlightDiagnostic remark = emitRemark() << "applying transformation"; if (Attribute message = getMessage()) @@ -98,7 +99,8 @@ "transform.test_transform_unrestricted_op_no_interface"); } - DiagnosedSilenceableFailure apply(transform::TransformResults &results, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -110,6 +112,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { results.set(cast(getResult()), @@ -129,6 +132,7 @@ 
DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToSelfOperand::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.setValues(llvm::cast(getOut()), getIn()); return DiagnosedSilenceableFailure::success(); @@ -143,7 +147,8 @@ DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToResult::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { if (target->getNumResults() <= getNumber()) return emitSilenceableError() << "payload has no result #" << getNumber(); @@ -160,7 +165,8 @@ DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { if (!target->getBlock()) return emitSilenceableError() << "payload has no parent block"; @@ -183,7 +189,8 @@ } DiagnosedSilenceableFailure -mlir::test::TestConsumeOperand::apply(transform::TransformResults &results, +mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -197,6 +204,7 @@ } DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto payload = state.getPayloadOps(getOperand()); assert(llvm::hasSingleElement(payload) && "expected a single target op"); @@ -237,6 +245,7 @@ } DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto payload = state.getPayloadOps(getOperand()); for (Operation *op : payload) @@ -252,6 +261,7 @@ } DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { ArrayRef values = state.getPayloadValues(getIn()); for (Value value : values) { @@ -277,15 +287,16 @@ transform::onlyReadsPayload(effects); } -DiagnosedSilenceableFailure -mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { state.addExtension(getMessageAttr()); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); if (!extension) { @@ -316,6 +327,7 @@ } DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); if (!extension) @@ -337,14 +349,15 @@ } DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( + transform::TransformRewriter 
&rewriter, transform::TransformResults &results, transform::TransformState &state) { state.removeExtension(); return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure -mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto payloadOps = state.getPayloadOps(getTarget()); auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); results.set(llvm::cast(getResult()), reversedOps); @@ -352,6 +365,7 @@ } DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -361,6 +375,7 @@ DiagnosedSilenceableFailure mlir::test::TestBranchingTransformOpTerminator::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -369,6 +384,7 @@ SmallVectorImpl &effects) {} DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { emitRemark() << getRemark(); for (Operation *op : state.getPayloadOps(getTarget())) @@ -386,7 +402,8 @@ } DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); @@ -395,7 +412,8 @@ DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { static int count = 0; if (count++ == 0) { @@ -407,7 +425,8 @@ DiagnosedSilenceableFailure mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); @@ -417,7 +436,8 @@ DiagnosedSilenceableFailure mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(nullptr); @@ -427,7 +447,8 @@ DiagnosedSilenceableFailure mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { if (target->hasAttr("target_me")) return DiagnosedSilenceableFailure::success(); @@ -436,6 +457,7 @@ DiagnosedSilenceableFailure 
mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { if (!getHandle()) emitRemark() << 0; @@ -449,7 +471,8 @@ } DiagnosedSilenceableFailure -mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results, +mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.set(llvm::cast(getCopy()), state.getPayloadOps(getHandle())); @@ -498,6 +521,7 @@ DiagnosedSilenceableFailure mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { int64_t count = 0; for (Operation *op : state.getPayloadOps(getTarget())) { @@ -520,7 +544,8 @@ } DiagnosedSilenceableFailure -mlir::test::TestPrintParamOp::apply(transform::TransformResults &results, +mlir::test::TestPrintParamOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { std::string str; llvm::raw_string_ostream os(str); @@ -537,7 +562,8 @@ } DiagnosedSilenceableFailure -mlir::test::TestAddToParamOp::apply(transform::TransformResults &results, +mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector values(/*Size=*/1, /*Value=*/0); if (Value param = getParam()) { @@ -559,6 +585,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceParamWithNumberOfTestOps::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { Builder builder(getContext()); SmallVector result = llvm::to_vector( @@ -577,6 +604,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceIntegerParamWithTypeOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { Attribute zero = IntegerAttr::get(getType(), 0); results.setParams(llvm::cast(getResult()), zero); @@ -599,7 +627,8 @@ DiagnosedSilenceableFailure mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne( - Operation *target, ::transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + ::transform::ApplyToEachResultList &results, ::transform::TransformState &state) { Builder builder(getContext()); if (getFirstResultIsParam()) { @@ -625,6 +654,7 @@ } DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector null({nullptr}); results.set(llvm::cast(getOut()), null); @@ -632,6 +662,7 @@ } DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.set(cast(getOut()), {}); return DiagnosedSilenceableFailure::success(); @@ -642,9 +673,9 @@ transform::producesHandle(getOut(), effects); } -DiagnosedSilenceableFailure -mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { 
results.setParams(llvm::cast(getOut()), Attribute()); return DiagnosedSilenceableFailure::success(); } @@ -654,9 +685,9 @@ transform::producesHandle(getOut(), effects); } -DiagnosedSilenceableFailure -mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.setValues(llvm::cast(getOut()), Value()); return DiagnosedSilenceableFailure::success(); } @@ -676,6 +707,7 @@ } DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.set(llvm::cast(getOut()), state.getPayloadOps(getIn())); return DiagnosedSilenceableFailure::success(); @@ -694,10 +726,9 @@ } DiagnosedSilenceableFailure -mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results, +mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { - transform::ErrorCheckingTrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); int64_t numIterations = 0; // `getPayloadOps` returns an iterator that skips ops that are erased in the diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -72,6 +72,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -89,6 +90,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -116,6 +118,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state) { @@ -260,6 +263,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -277,6 +281,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -295,6 +300,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState 
&state); @@ -313,6 +319,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -329,6 +336,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -428,6 +436,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -498,7 +507,8 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ - DiagnosedSilenceableFailure apply(transform::TransformResults &results, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { llvm_unreachable("op should not be used as a transform"); return DiagnosedSilenceableFailure::definiteFailure();
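
For downstream users of the Transform dialect, a minimal sketch (not part of the patch) of how an out-of-tree transform op adapts to the new `apply` signature and the `TransformRewriter` API introduced in TransformInterfaces.cpp above. The op `MyEraseTargetOp`, its `getTarget()` accessor, and its payload logic are hypothetical; only the signature, `hasTrackingFailures()`, and `silenceTrackingFailure()` are taken from this change.

```c++
// Hypothetical out-of-tree op shown for illustration only. It erases all
// payload ops mapped to its operand handle through the provided rewriter so
// that the tracking listener can update the handle mappings.
::mlir::DiagnosedSilenceableFailure MyEraseTargetOp::apply(
    ::mlir::transform::TransformRewriter &rewriter,
    ::mlir::transform::TransformResults &results,
    ::mlir::transform::TransformState &state) {
  for (::mlir::Operation *payloadOp :
       llvm::to_vector(state.getPayloadOps(getTarget()))) {
    // All IR must be modified through the provided rewriter; erasing the op
    // directly would leave stale pointers in the transform state.
    rewriter.eraseOp(payloadOp);
  }
  // If this op tolerates untracked replacements made by nested rewrites, the
  // accumulated tracking errors can be dropped explicitly.
  if (rewriter.hasTrackingFailures())
    rewriter.silenceTrackingFailure();
  return ::mlir::DiagnosedSilenceableFailure::success();
}
```

Materializing the payload range into a vector before the loop avoids iterating a range whose elements are being erased, mirroring what ops such as SplitOp do in this patch.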