diff --git a/mlir/docs/Tutorials/transform/Ch1.md b/mlir/docs/Tutorials/transform/Ch1.md --- a/mlir/docs/Tutorials/transform/Ch1.md +++ b/mlir/docs/Tutorials/transform/Ch1.md @@ -362,3 +362,7 @@ Note that the “add” elementwise operation is indicated as payload ancestor because it was used to produce the tile loop, and the loop therefore has its location. Finally, we would like to replace the call to the outlined function with a call to the microkernel. Unfortunately, the Transform dialect doesn’t have support for this transformation (and cannot have if the call is rewritten to a custom, out-of-tree operation). Therefore, we need to define new transform operations. The next chapters will describe how this can be done. + +## Tracking IR Modifications + +The Transform dialect automatically tracks all IR changes made by transform ops; for this to work, transform op implementations must use the provided rewriter to modify IR. If a payload op is erased, it is automatically removed from all handles it is currently associated with. If a payload op is replaced, the Transform dialect tries to find the replacement op and updates all handles accordingly. If a multi-result op is replaced with values that are defined by multiple ops, or if an op is replaced with an op of a different type, an error is produced, because it is unclear whether the direct replacements actually represent the computation of the original op. This behavior can be customized; more details can be found in the documentation of `transform::TrackingListener`. diff --git a/mlir/docs/Tutorials/transform/Ch2.md b/mlir/docs/Tutorials/transform/Ch2.md --- a/mlir/docs/Tutorials/transform/Ch2.md +++ b/mlir/docs/Tutorials/transform/Ch2.md @@ -189,7 +189,7 @@ } ``` -To finalize the definition of the transform operation, we need to implement the interface methods. The `TransformOpInterface` currently requires only one method – `apply` – that performs the actual transformation. 
It is a good practice to limit the body of the method to manipulation of the Transform dialect constructs and have the actual transformation implemented as a standalone function so it can be used from other places in the code. +To finalize the definition of the transform operation, we need to implement the interface methods. The `TransformOpInterface` currently requires only one method – `apply` – that performs the actual transformation. It is good practice to limit the body of the method to manipulating Transform dialect constructs and to implement the actual transformation as a standalone function so that it can be reused from other places in the code. As with rewrite patterns, all IR must be modified through the provided rewriter. ```c++ @@ -198,23 +198,25 @@ // Implementation of our transform dialect operation. // This operation returns a tri-state result that can be one of: // - success when the transformation succeeded; -// - definite failure when the transformation failed in such a way that following -// transformations are impossible or undesirable, typically it could have left payload -// IR in an invalid state; it is expected that a diagnostic is emitted immediately -// before returning the definite error; -// - silenceable failure when the transformation failed but following transformations -// are still applicable, typically this means a precondition for the transformation is -// not satisfied and the payload IR has not been modified. -// The silenceable failure additionally carries a Diagnostic that can be emitted to the -// user. -::mlir::DiagnosedSilenceableFailure ChangeCallTargetOp::apply( - // The list of payload IR entities that will be associated with the transform IR - // values defined by this transform operation. In this case, it can remain empty as - // there are no results. 
+// - definite failure when the transformation failed in such a way that +// following transformations are impossible or undesirable, typically it could +// have left payload IR in an invalid state; it is expected that a diagnostic +// is emitted immediately before returning the definite error; +// - silenceable failure when the transformation failed but following +// transformations are still applicable, typically this means a precondition +// for the transformation is not satisfied and the payload IR has not been +// modified. The silenceable failure additionally carries a Diagnostic that +// can be emitted to the user. +::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply( + // The rewriter that should be used when modifying IR. + ::mlir::transform::TransformRewriter &rewriter, + // The list of payload IR entities that will be associated with the + // transform IR values defined by this transform operation. In this case, it + // can remain empty as there are no results. ::mlir::transform::TransformResults &results, - // The transform application state. This object can be used to query the current - // associations between transform IR values and payload IR entities. It can also - // carry additional user-defined state. + // The transform application state. This object can be used to query the + // current associations between transform IR values and payload IR entities. + // It can also carry additional user-defined state. ::mlir::transform::TransformState &state) { // First, we need to obtain the list of payload operations that are associated with diff --git a/mlir/docs/Tutorials/transform/Ch3.md b/mlir/docs/Tutorials/transform/Ch3.md --- a/mlir/docs/Tutorials/transform/Ch3.md +++ b/mlir/docs/Tutorials/transform/Ch3.md @@ -43,6 +43,7 @@ // Declare the function implementing the interface for a single payload operation. 
let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::func::CallOp call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -54,7 +55,8 @@ ```c++ ::mlir::DiagnosedSilenceableFailure ChangeCallTargetOp::applyToOne( - ::mlir::func::CallOp call,, + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::func::CallOp call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state) { // Call the actual transformation function. @@ -176,6 +178,7 @@ // Declare the function implementing the interface for a single payload operation. let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::CallOpInterface call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -189,6 +192,7 @@ // In MyExtension.cpp. ::mlir::DiagnosedSilenceableFailure CallToOp::applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::CallOpInterface call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state) { @@ -228,6 +232,7 @@ ```c++ ::mlir::DiagnosedSilenceableFailure SomeOtherOp::apply( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::transform::TransformResults &results, ::mlir::transform::TransformState &state) { // ... @@ -273,6 +278,7 @@ // Declare the function implementing the interface for a single payload operation. 
let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::CallOpInterface call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/examples/transform/Ch2/lib/MyExtension.cpp b/mlir/examples/transform/Ch2/lib/MyExtension.cpp --- a/mlir/examples/transform/Ch2/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch2/lib/MyExtension.cpp @@ -74,17 +74,17 @@ // This operation returns a tri-state result that can be one of: // - success when the transformation succeeded; // - definite failure when the transformation failed in such a way that -// following -// transformations are impossible or undesirable, typically it could have left -// payload IR in an invalid state; it is expected that a diagnostic is emitted -// immediately before returning the definite error; +// following transformations are impossible or undesirable, typically it could +// have left payload IR in an invalid state; it is expected that a diagnostic +// is emitted immediately before returning the definite error; // - silenceable failure when the transformation failed but following -// transformations -// are still applicable, typically this means a precondition for the -// transformation is not satisfied and the payload IR has not been modified. -// The silenceable failure additionally carries a Diagnostic that can be emitted -// to the user. +// transformations are still applicable, typically this means a precondition +// for the transformation is not satisfied and the payload IR has not been +// modified. The silenceable failure additionally carries a Diagnostic that +// can be emitted to the user. ::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply( + // The rewriter that should be used when modifying IR. 
+ ::mlir::transform::TransformRewriter &rewriter, // The list of payload IR entities that will be associated with the // transform IR values defined by this transform operation. In this case, it // can remain empty as there are no results. diff --git a/mlir/examples/transform/Ch3/include/MyExtension.td b/mlir/examples/transform/Ch3/include/MyExtension.td --- a/mlir/examples/transform/Ch3/include/MyExtension.td +++ b/mlir/examples/transform/Ch3/include/MyExtension.td @@ -60,6 +60,7 @@ // Declare the function implementing the interface for a single payload operation. let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::func::CallOp call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -89,6 +90,7 @@ // Declare the function implementing the interface for a single payload operation. let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::CallOpInterface call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/examples/transform/Ch3/lib/MyExtension.cpp b/mlir/examples/transform/Ch3/lib/MyExtension.cpp --- a/mlir/examples/transform/Ch3/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch3/lib/MyExtension.cpp @@ -113,6 +113,8 @@ // to the user. ::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::applyToOne( + // The rewriter that should be used when modifying IR. + ::mlir::transform::TransformRewriter &rewriter, // The single payload operation to which the transformation is applied. 
::mlir::func::CallOp call, // The payload IR entities that will be appended to lists associated with @@ -146,27 +148,27 @@ // CallToOp //===---------------------------------------------------------------------===// -static mlir::Operation *replaceCallWithOp(mlir::CallOpInterface call) { +static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter, + mlir::CallOpInterface call) { // Construct an operation from an unregistered dialect. This is discouraged // and is only used here for brevity of the overall example. mlir::OperationState state(call.getLoc(), "my.mm4"); state.types.assign(call->result_type_begin(), call->result_type_end()); state.operands.assign(call->operand_begin(), call->operand_end()); - mlir::OpBuilder builder(call); - mlir::Operation *replacement = builder.create(state); - call->replaceAllUsesWith(replacement->getResults()); - call->erase(); + mlir::Operation *replacement = rewriter.create(state); + rewriter.replaceOp(call, replacement->getResults()); return replacement; } // See above for the signature description. mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne( - mlir::CallOpInterface call, mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call, + mlir::transform::ApplyToEachResultList &results, mlir::transform::TransformState &state) { // Dispatch to the actual transformation. - Operation *replacement = replaceCallWithOp(call); + Operation *replacement = replaceCallWithOp(rewriter, call); // Associate the payload operation produced by the rewrite with the result // handle of this transform operation. 
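The `replaceCallWithOp` change above routes the replacement through `rewriter.replaceOp` instead of calling `replaceAllUsesWith` and `erase` directly. The rewriter reports IR changes to a listener, which is what allows the Transform dialect to keep handles up to date (see the "Tracking IR Modifications" section added to Ch1). The standalone C++ sketch below is only a toy model of that idea. `ToyOp`, `ToyListener`, `ToyRewriter` and `HandleTracker` are hypothetical names. Op "kinds" are plain strings standing in for op names. The tracking rules are a simplified reading of the Ch1 text, not the actual `transform::TrackingListener` implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy model (not MLIR API): a payload op is a unique id plus a
// "kind" string standing in for the op name, e.g. "func.call".
struct ToyOp {
  int id;
  std::string kind;
};

// Hypothetical listener: the rewriter reports every IR change to it.
class ToyListener {
public:
  virtual ~ToyListener() = default;
  // `newDefs` holds, per result of `oldOp`, the op defining its replacement.
  virtual void notifyReplaced(const ToyOp &oldOp,
                              const std::vector<ToyOp> &newDefs) = 0;
  virtual void notifyErased(const ToyOp &op) = 0;
};

// Hypothetical rewriter: owns the payload ops and notifies the listener
// before each replacement or erasure.
class ToyRewriter {
public:
  explicit ToyRewriter(ToyListener *listener) : listener(listener) {}

  int create(const std::string &kind) {
    int id = nextId++;
    ops[id] = ToyOp{id, kind};
    return id;
  }
  // Replace `oldId` with values defined by `newDefIds` (one per result).
  void replaceOp(int oldId, const std::vector<int> &newDefIds) {
    std::vector<ToyOp> newDefs;
    for (int id : newDefIds)
      newDefs.push_back(ops.at(id));
    if (listener)
      listener->notifyReplaced(ops.at(oldId), newDefs);
    ops.erase(oldId);
  }
  void eraseOp(int id) {
    if (listener)
      listener->notifyErased(ops.at(id));
    ops.erase(id);
  }
  bool exists(int id) const { return ops.count(id) != 0; }
  // Direct mutation that bypasses the listener entirely.
  void eraseWithoutNotifying(int id) { ops.erase(id); }

private:
  ToyListener *listener;
  std::map<int, ToyOp> ops;
  int nextId = 0;
};

// Hypothetical tracker: erased ops leave their handles; a single replacement
// op of the same kind is substituted; anything else is recorded as an error
// (the handle is left unchanged in that case).
class HandleTracker : public ToyListener {
public:
  std::map<std::string, std::vector<int>> handles;
  std::vector<std::string> errors;

  void notifyErased(const ToyOp &op) override {
    for (auto &entry : handles) {
      auto &ids = entry.second;
      ids.erase(std::remove(ids.begin(), ids.end(), op.id), ids.end());
    }
  }

  void notifyReplaced(const ToyOp &oldOp,
                      const std::vector<ToyOp> &newDefs) override {
    std::set<int> distinct;
    for (const ToyOp &def : newDefs)
      distinct.insert(def.id);
    if (distinct.size() != 1) {
      errors.push_back("op " + std::to_string(oldOp.id) +
                       ": replacement values not defined by a single op");
      return;
    }
    const ToyOp &replacement = newDefs.front();
    if (replacement.kind != oldOp.kind) {
      errors.push_back("op " + std::to_string(oldOp.id) +
                       ": replaced by an op of a different kind");
      return;
    }
    for (auto &entry : handles)
      std::replace(entry.second.begin(), entry.second.end(), oldOp.id,
                   replacement.id);
  }
};
```

With `HandleTracker` attached, `replaceOp` with a same-kind op updates every handle and `eraseOp` drops the op from its handles. `eraseWithoutNotifying` models direct mutation: the handle keeps a stale id because no listener was told about the change. Routing every change through the provided rewriter avoids exactly that situation.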
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -152,6 +152,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::EmptyOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td --- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td @@ -126,6 +126,7 @@ }]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -189,6 +190,7 @@ }]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -155,6 +155,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, 
::mlir::transform::TransformState &state); @@ -270,6 +271,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -313,6 +315,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::GenericOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -350,6 +353,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::PackOp target, ::mlir::transform::ApplyToEachResultList &transformResults, ::mlir::transform::TransformState &state); @@ -389,6 +393,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::UnPackOp target, ::mlir::transform::ApplyToEachResultList &transformResults, ::mlir::transform::TransformState &state); @@ -531,6 +536,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, TransformState &state); @@ -837,6 +843,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -931,6 +938,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::PadOp, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -978,6 +986,7 @@ let 
extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1051,6 +1060,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1098,6 +1108,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1309,6 +1320,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1413,6 +1425,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1522,6 +1535,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1726,6 +1740,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::transform::TransformResults &transformResults, ::mlir::transform::TransformState &state); @@ -1862,6 +1877,7 @@ ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + 
::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1914,6 +1930,7 @@ let extraClassDeclaration = [{ // TODO: applyToOne. ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::transform::TransformResults &transformResults, ::mlir::transform::TransformState &state); @@ -1958,7 +1975,8 @@ ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::func::FuncOp target, + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::func::FuncOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; @@ -2034,6 +2052,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -2082,6 +2101,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -2120,6 +2140,7 @@ ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -158,6 +158,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + 
::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -126,6 +126,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::scf::ForOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -168,6 +169,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::scf::ForOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -201,6 +203,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -228,6 +231,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -269,6 +273,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::scf::IfOp ifOp, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td --- 
a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td +++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td @@ -138,6 +138,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h @@ -50,7 +50,8 @@ return success(); } - DiagnosedSilenceableFailure apply(TransformResults &results, + DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { Value operandHandle = cast(this->getOperation()).getOperandHandle(); auto payload = state.getPayloadOps(operandHandle); @@ -90,7 +91,8 @@ return success(); } - DiagnosedSilenceableFailure apply(TransformResults &results, + DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { Value operandHandle = cast(this->getOperation()).getOperandHandle(); ValueRange payload = state.getPayloadValues(operandHandle); diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -18,11 +18,11 @@ #include "mlir/Support/LogicalResult.h" namespace mlir { - namespace transform { class TransformOpInterface; class TransformResults; +class TransformRewriter; class TransformState; using Param = Attribute; @@ -856,8 +856,7 @@ public TransformState::Extension { public: /// Create a new TrackingListener for usage in the specified transform op. 
- explicit TrackingListener(TransformState &state, TransformOpInterface op) - : TransformState::Extension(state), transformOp(op) {} + TrackingListener(TransformState &state, TransformOpInterface op); protected: /// Return a replacement payload op for the given op, which is going to be @@ -937,6 +936,9 @@ /// The transform op in which this TrackingListener is used. TransformOpInterface transformOp; + + /// The handles that are consumed by the transform op. + DenseSet consumedHandles; }; /// A specialized listener that keeps track of cases in which no replacement @@ -968,6 +970,17 @@ int64_t errorCounter = 0; }; +/// This is a special rewriter to be used in transform op implementations, +/// providing additional helper functions to update the transform state, etc. +// TODO: Helper functions will be added in a subsequent change. +class TransformRewriter : public RewriterBase { +protected: + friend class TransformState; + + /// Create a new TransformRewriter. + explicit TransformRewriter(MLIRContext *ctx); +}; + /// This trait is supposed to be attached to Transform dialect operations that /// can be standalone top-level transforms. Such operations typically contain /// other Transform dialect operations that can be executed following some @@ -1064,7 +1077,8 @@ /// 5. If any `applyToOne` return silenceableFailure, the transformation is /// considered silenceable. /// 6. Otherwise the transformation is considered successful. - DiagnosedSilenceableFailure apply(TransformResults &transformResults, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state); /// Checks that the op matches the expectations of this trait. @@ -1307,14 +1321,14 @@ /// `targets` contains operations of the same class and a silenceable failure /// is reported if it does not. 
template -DiagnosedSilenceableFailure -applyTransformToEach(TransformOpTy transformOp, Range &&targets, - SmallVectorImpl &results, - TransformState &state) { +DiagnosedSilenceableFailure applyTransformToEach( + TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets, + SmallVectorImpl &results, TransformState &state) { using OpTy = typename llvm::function_traits< - decltype(&TransformOpTy::applyToOne)>::template arg_t<0>; + decltype(&TransformOpTy::applyToOne)>::template arg_t<1>; static_assert(std::is_convertible::value, "expected transform function to take an operation"); + OpBuilder::InsertionGuard g(rewriter); SmallVector silenceableStack; unsigned expectedNumResults = transformOp->getNumResults(); @@ -1331,8 +1345,9 @@ ApplyToEachResultList partialResults; partialResults.reserve(expectedNumResults); Location specificOpLoc = specificOp->getLoc(); + rewriter.setInsertionPoint(specificOp); DiagnosedSilenceableFailure res = - transformOp.applyToOne(specificOp, partialResults, state); + transformOp.applyToOne(rewriter, specificOp, partialResults, state); if (res.isDefiniteFailure()) return DiagnosedSilenceableFailure::definiteFailure(); @@ -1367,7 +1382,8 @@ template mlir::DiagnosedSilenceableFailure mlir::transform::TransformEachOpTrait::apply( - TransformResults &transformResults, TransformState &state) { + TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { Value handle = this->getOperation()->getOperand(0); auto targets = state.getPayloadOps(handle); @@ -1403,7 +1419,7 @@ // corresponding results entry. SmallVector results; DiagnosedSilenceableFailure result = detail::applyTransformToEach( - cast(this->getOperation()), targets, results, state); + cast(this->getOperation()), rewriter, targets, results, state); // Step 3. Propagate the definite failure if any and bail out. 
if (result.isDefiniteFailure()) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -41,10 +41,15 @@ transformation represented by the current op is targeted. Returns a special status object indicating whether the transformation succeeded or failed, and, if it failed, whether the failure is recoverable or not. + + IR must be created, modified and deleted with the provided rewriter. + Implementations are responsible for setting the insertion point of the + rewriter to the desired location. }], /*returnType=*/"::mlir::DiagnosedSilenceableFailure", /*name=*/"apply", /*arguments=*/(ins + "::mlir::transform::TransformRewriter &":$rewriter, "::mlir::transform::TransformResults &":$transformResults, "::mlir::transform::TransformState &":$state )>, diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -177,6 +177,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -204,6 +205,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++
b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -63,7 +63,8 @@ } // namespace DiagnosedSilenceableFailure -SimplifyBoundedAffineOpsOp::apply(TransformResults &results, +SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { // Get constraints for bounded values. SmallVector lbs; @@ -127,6 +128,8 @@ SimplifyAffineMinMaxOp>(getContext(), cstr); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); GreedyRewriteConfig config; + config.listener = + static_cast(rewriter.getListener()); config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; // Apply the simplification pattern to a fixpoint. if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) { diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -26,7 +26,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::OneShotBufferizeOp::apply(TransformResults &transformResults, +transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { OneShotBufferizationOptions options; options.allowReturnAllocs = getAllowReturnAllocs(); @@ -71,10 +72,9 @@ modifiesPayload(effects); } -DiagnosedSilenceableFailure -transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults, - TransformState &state) { - IRRewriter rewriter(getContext()); +DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( + transform::TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { OneShotBufferizationOptions options; options.allowReturnAllocs = true; @@ -95,11 
+95,9 @@ // EmptyTensorToAllocTensorOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target, - ApplyToEachResultList &results, - transform::TransformState &state) { - IRRewriter rewriter(target->getContext()); +DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne( + transform::TransformRewriter &rewriter, tensor::EmptyOp target, + ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); auto alloc = rewriter.replaceOpWithNewOp( target, target.getType(), target.getDynamicSizes()); diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -398,7 +398,7 @@ /// Alter kernel configuration of the given kernel. static DiagnosedSilenceableFailure -alterGpuLaunch(IRRewriter &rewriter, LaunchOp gpuLaunch, +alterGpuLaunch(RewriterBase &rewriter, LaunchOp gpuLaunch, TransformOpInterface transformOp, std::optional gridDimX = std::nullopt, std::optional gridDimY = std::nullopt, @@ -661,12 +661,10 @@ return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure -transform::MapForallToBlocks::applyToOne(Operation *target, - ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, transform::TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); - IRRewriter rewriter(getContext()); auto transformOp = cast(getOperation()); if (!getGenerateGpuLaunch() && !gpuLaunch) { @@ -856,7 +854,8 @@ } DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne( - Operation *target, ApplyToEachResultList &results, 
TransformState &state) { + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); auto transformOp = cast(getOperation()); @@ -877,7 +876,6 @@ // Set the GPU launch configuration for the block dims early, this is not // subject to IR inspection. - IRRewriter rewriter(getContext()); diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt, std::nullopt, std::nullopt, blockDims[0], blockDims[1], blockDims[2]); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -169,13 +169,11 @@ // BufferizeToAllocationOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -transform::BufferizeToAllocationOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { Attribute memorySpace = getMemorySpace().has_value() ? 
getMemorySpace().value() : Attribute(); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); auto transformed = llvm::to_vector( llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) { return linalg::bufferizeToAllocation(rewriter, v, memorySpace); @@ -196,7 +194,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::DecomposeOp::applyToOne(LinalgOp target, +transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { #define DOWNSCALE(trans) \ @@ -286,7 +285,8 @@ } DiagnosedSilenceableFailure -transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, +transform::FuseOp::apply(transform::TransformRewriter &rewriter, + mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { SmallVector tileSizes = extractFromI64ArrayAttr(getTileSizes()); SmallVector tileInterchange = @@ -297,8 +297,6 @@ tilingOptions = tilingOptions.setTileSizes(tileSizes); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); LogicalResult result = applyTilingToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), transformResults, @@ -721,7 +719,8 @@ } DiagnosedSilenceableFailure -transform::FuseIntoContainingOp::apply(transform::TransformResults &results, +transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector fusedOps; auto producerOps = state.getPayloadOps(getProducerOp()); @@ -764,8 +763,6 @@ return failure(); }; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); while 
(!remainingProducers.empty()) { auto nextProducer = getNextProducer(); if (failed(nextProducer)) { @@ -842,7 +839,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GeneralizeOp::applyToOne(LinalgOp target, +transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Exit early if no transformation is needed. @@ -850,8 +848,6 @@ results.push_back(target); return DiagnosedSilenceableFailure::success(); } - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr generic = generalizeNamedOp(rewriter, target); if (succeeded(generic)) { @@ -866,7 +862,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::InterchangeOp::applyToOne(GenericOp target, +transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter, + GenericOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { ArrayRef interchangeVector = getIteratorInterchange(); @@ -875,8 +872,6 @@ results.push_back(target); return DiagnosedSilenceableFailure::success(); } - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr res = interchangeGenericOp(rewriter, target, SmallVector(interchangeVector.begin(), @@ -904,10 +899,9 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne( - tensor::PackOp target, transform::ApplyToEachResultList &transformResults, + transform::TransformRewriter &rewriter, tensor::PackOp target, + transform::ApplyToEachResultList &transformResults, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); 
rewriter.setInsertionPoint(target); FailureOr res = lowerPack(rewriter, target); if (failed(res)) { @@ -925,10 +919,9 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne( - tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults, + transform::TransformRewriter &rewriter, tensor::UnPackOp target, + transform::ApplyToEachResultList &transformResults, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr res = lowerUnPack(rewriter, target); if (failed(res)) { @@ -964,7 +957,8 @@ } DiagnosedSilenceableFailure -transform::MatchOp::apply(transform::TransformResults &results, +transform::MatchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { llvm::StringSet<> strs; if (getOps().has_value()) @@ -1053,8 +1047,8 @@ } DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, - TransformState &state) { + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, TransformState &state) { if (isa(getLowSize().getType())) { if (target.hasDynamicShape()) { auto diag = emitSilenceableError() @@ -1155,7 +1149,8 @@ } DiagnosedSilenceableFailure -transform::PackOp::apply(transform::TransformResults &transformResults, +transform::PackOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); // If nothing to pack, propagate success. 
@@ -1184,8 +1179,6 @@ DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations( state, *this, packedSizes, getMixedPackedSizes()); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(linalgOp); FailureOr maybeResult = pack(rewriter, linalgOp, packedSizes); if (failed(maybeResult)) @@ -1364,11 +1357,10 @@ } DiagnosedSilenceableFailure -PackGreedilyOp::apply(transform::TransformResults &transformResults, +PackGreedilyOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); for (Operation *op : state.getPayloadOps(getTarget())) { auto linalgOp = dyn_cast(op); if (!linalgOp) @@ -1464,7 +1456,8 @@ } DiagnosedSilenceableFailure -transform::PackTransposeOp::apply(transform::TransformResults &transformResults, +transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp()); auto linalgOps = state.getPayloadOps(getTargetLinalgOp()); @@ -1542,8 +1535,6 @@ assert(packOp && linalgOp && "unexpected null op"); // Step 3. Actually transpose the ops. - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr res = packTranspose( rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm()); // Preconditions have been checked, it is an error to fail here. 
@@ -1568,7 +1559,8 @@ //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PadOp::applyToOne(LinalgOp target, +transform::PadOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Convert the integer packing flags to booleans. @@ -1616,8 +1608,6 @@ transposePaddings.push_back( extractFromI64ArrayAttr(cast(transposeVector))); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); LinalgOp paddedOp; SmallVector paddingDimensions = extractFromI64ArrayAttr(getPaddingDimensions()); @@ -1684,6 +1674,7 @@ //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); @@ -1700,8 +1691,6 @@ if (!padOp || !loopOp) return emitDefiniteFailure() << "requires exactly 2 non-null handles"; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr result = linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp, getTranspose()); @@ -1740,13 +1729,12 @@ } DiagnosedSilenceableFailure -transform::HoistPadOp::applyToOne(tensor::PadOp target, +transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter, + tensor::PadOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { tensor::PadOp hoistedPadOp; SmallVector transposeOps; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr result = hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(), hoistedPadOp, transposeOps); @@ -1779,7 +1767,8 @@ 
//===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PromoteOp::applyToOne(LinalgOp target, +transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { LinalgPromotionOptions promotionOptions; @@ -1829,8 +1818,6 @@ if (failed(promoteSubviewsPrecondition(target, promotionOptions))) return emitDefaultDefiniteFailure(target); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) @@ -1844,7 +1831,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ReplaceOp::apply(TransformResults &transformResults, +transform::ReplaceOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { auto payload = state.getPayloadOps(getTarget()); @@ -1859,8 +1847,6 @@ } // Clone and replace. 
- TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); Operation *pattern = &getBodyRegion().front().front(); SmallVector replacements; for (Operation *target : payload) { @@ -1904,7 +1890,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ScalarizeOp::applyToOne(LinalgOp target, +transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::SCFTilingOptions tilingOptions; @@ -1916,8 +1903,6 @@ AffineMap map = target.getShapesToLoopsMap(); if (!map) return tileSizes; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); SmallVector shapeSizes = affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map, allShapeSizes); @@ -1931,8 +1916,6 @@ return tileSizes; }); SmallVector emptyTileSizes; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(target.getOperation()), tilingOptions); @@ -1956,11 +1939,10 @@ DiagnosedSilenceableFailure transform::RewriteInDestinationPassingStyleOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { SmallVector res; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr maybeResult = TypeSwitch>(target) @@ -1978,13 +1960,12 @@ // SplitOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, - TransformState &state) { +DiagnosedSilenceableFailure +SplitOp::apply(transform::TransformRewriter 
&rewriter, + TransformResults &results, TransformState &state) { // Collect the dynamic split points if provided. SmallVector payload = llvm::to_vector(state.getPayloadOps(getTarget())); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); SmallVector splitPoints; splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { @@ -2184,15 +2165,14 @@ } DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return linalg::SplitReductionOptions{int64_t(getSplitFactor()), unsigned(getInsertSplitDimension()), bool(getInnerParallel())}; }; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr splitResult = (getUseScalingAlgorithm()) @@ -2230,10 +2210,9 @@ } DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr result = scf::tileReductionUsingScf( rewriter, cast(target.getOperation()), @@ -2274,10 +2253,9 @@ } DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); SmallVector numThreads = 
getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); @@ -2363,7 +2341,8 @@ } DiagnosedSilenceableFailure -transform::TileOp::apply(TransformResults &transformResults, +transform::TileOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { ArrayRef tileSizes = getStaticSizes(); @@ -2478,8 +2457,6 @@ } tilingOptions.setInterchange(getInterchange()); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr maybeTilingResult = tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions); if (failed(maybeTilingResult)) @@ -2720,45 +2697,44 @@ } DiagnosedSilenceableFailure -transform::TileToForallOp::apply(transform::TransformResults &transformResults, +transform::TileToForallOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); - auto transformOp = cast(getOperation()); - - // Result payload ops. - SmallVector tileOps; - SmallVector tiledOps; - - // Unpack handles. - SmallVector mixedNumThreads; - DiagnosedSilenceableFailure status = - getPackedNumThreads() - ? unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedNumThreads, getPackedNumThreads()) - : unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedNumThreads, getMixedNumThreads()); - if (!status.succeeded()) - return status; - SmallVector mixedTileSizes; - status = getPackedTileSizes() - ? 
unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedTileSizes, getPackedTileSizes()) - : unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedTileSizes, getMixedTileSizes()); - if (!status.succeeded()) - return status; - - for (Operation *target : state.getPayloadOps(getTarget())) { - linalg::ForallTilingResult tilingResult; - DiagnosedSilenceableFailure diag = tileToForallOpImpl( - rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, - getMapping(), tilingResult); - if (!diag.succeeded()) + auto transformOp = cast(getOperation()); + + // Result payload ops. + SmallVector tileOps; + SmallVector tiledOps; + + // Unpack handles. + SmallVector mixedNumThreads; + DiagnosedSilenceableFailure status = + getPackedNumThreads() + ? unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedNumThreads, getPackedNumThreads()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedNumThreads, getMixedNumThreads()); + if (!status.succeeded()) + return status; + SmallVector mixedTileSizes; + status = getPackedTileSizes() + ? 
unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getPackedTileSizes()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getMixedTileSizes()); + if (!status.succeeded()) + return status; + + for (Operation *target : state.getPayloadOps(getTarget())) { + linalg::ForallTilingResult tilingResult; + DiagnosedSilenceableFailure diag = tileToForallOpImpl( + rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, + getMapping(), tilingResult); + if (!diag.succeeded()) return diag; tileOps.push_back(tilingResult.tileOp); tiledOps.push_back(tilingResult.tiledOp); - } + } transformResults.set(cast(getForallOp()), tileOps); transformResults.set(cast(getTiledOp()), tiledOps); @@ -2833,7 +2809,8 @@ } DiagnosedSilenceableFailure -transform::TileToScfForOp::apply(TransformResults &transformResults, +transform::TileToScfForOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { ArrayRef tileSizes = getStaticSizes(); @@ -2902,8 +2879,6 @@ } tilingOptions.setInterchange(getInterchange()); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr tilingResult = tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions); if (failed(tilingResult)) @@ -3047,7 +3022,8 @@ } // namespace DiagnosedSilenceableFailure -transform::VectorizeOp::applyToOne(Operation *target, +transform::VectorizeOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { if (!target->hasTrait()) { @@ -3092,10 +3068,9 @@ // MaskedVectorizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply( + transform::TransformRewriter &rewriter, mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { - TrackingListener 
listener(state, *this); - IRRewriter rewriter(getContext(), &listener); auto targets = state.getPayloadOps(getTarget()); if (std::empty(targets)) return DiagnosedSilenceableFailure::success(); @@ -3173,7 +3148,8 @@ DiagnosedSilenceableFailure transform::HoistRedundantVectorTransfersOp::applyToOne( - func::FuncOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, func::FuncOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // WARNING: This hoisting does not model parallelism and is generally // incorrect when used on distributed loops with memref semantics! @@ -3188,10 +3164,9 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne( - linalg::LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, linalg::LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); auto maybeTransformed = TypeSwitch>>( @@ -3223,10 +3198,9 @@ DiagnosedSilenceableFailure transform::HoistRedundantTensorSubsetsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); auto forOp = dyn_cast(target); if (forOp) { linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp); @@ -3289,11 +3263,10 @@ } DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne( - Operation *targetOp, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *targetOp, + transform::ApplyToEachResultList &results, 
transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(targetOp->getContext(), &listener); rewriter.setInsertionPoint(targetOp); if (auto target = dyn_cast(targetOp)) return doit(rewriter, target, results, state); diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -60,10 +60,10 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; - IRRewriter rewriter(getContext()); for (Operation *op : state.getPayloadOps(getTarget())) { bool canApplyMultiBuffer = true; auto target = cast(op); @@ -105,7 +105,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // Gather IVs. SmallVector ivs; @@ -123,7 +124,6 @@ } // Rewrite IR. 
- IRRewriter rewriter(target->getContext()); FailureOr replacement = failure(); if (auto allocaOp = dyn_cast(target)) { replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -35,7 +35,8 @@ // GetParentForOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetParentForOp::apply(transform::TransformResults &results, +transform::GetParentForOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SetVector parents; for (Operation *target : state.getPayloadOps(getTarget())) { @@ -92,7 +93,8 @@ } DiagnosedSilenceableFailure -transform::LoopOutlineOp::apply(transform::TransformResults &results, +transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector functions; SmallVector calls; @@ -100,7 +102,6 @@ for (Operation *target : state.getPayloadOps(getTarget())) { Location location = target->getLoc(); Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); - IRRewriter rewriter(getContext()); scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); if (!exec) { DiagnosedSilenceableFailure diag = emitSilenceableError() @@ -135,11 +136,11 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::LoopPeelOp::applyToOne(scf::ForOp target, +transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter, + scf::ForOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::ForOp result; - IRRewriter rewriter(target->getContext()); // This helper 
returns failure when peeling does not occur (i.e. when the IR // is not modified). This is not a failure for the op as the postcondition: // "the loop trip count is divisible by the step" @@ -192,7 +193,8 @@ } DiagnosedSilenceableFailure -transform::LoopPipelineOp::applyToOne(scf::ForOp target, +transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter, + scf::ForOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::PipeliningOption options; @@ -203,7 +205,6 @@ getReadLatency()); }; scf::ForLoopPipeliningPattern pattern(options, target->getContext()); - IRRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr patternResult = scf::pipelineForLoop(rewriter, target, options); @@ -219,7 +220,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::LoopUnrollOp::applyToOne(Operation *op, +transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *op, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); @@ -241,7 +243,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::LoopCoalesceOp::applyToOne(Operation *op, +transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *op, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); @@ -276,12 +279,10 @@ } DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne( - scf::IfOp ifOp, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, scf::IfOp ifOp, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(ifOp->getContext(), &listener); rewriter.setInsertionPoint(ifOp); - Region &region =
getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion(); if (!llvm::hasSingleElement(region)) { diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -123,7 +123,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // Gather IVs. SmallVector ivs; @@ -141,7 +142,6 @@ } // Rewrite IR. - IRRewriter rewriter(target->getContext()); FailureOr replacement = failure(); if (auto padOp = dyn_cast(target)) { replacement = tensor::buildIndependentOp(rewriter, padOp, ivs); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -896,12 +896,32 @@ return diag; } + // Prepare rewriter and listener. + transform::ErrorCheckingTrackingListener trackingListener(*this, transform); + transform::TransformRewriter rewriter(transform->getContext()); + rewriter.setListener(&trackingListener); + // Compute the result but do not short-circuit the silenceable failure case as // we still want the handles to propagate properly so the "suppress" mode can // proceed on a best effort basis. transform::TransformResults results(transform->getNumResults()); - DiagnosedSilenceableFailure result(transform.apply(results, *this)); + DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this)); compactOpHandles(); + + // Error handling: fail if transform or listener failed. 
+ DiagnosedSilenceableFailure trackingFailure = + trackingListener.checkAndResetError(); + if (!trackingFailure.succeeded()) { + if (result.succeeded()) { + result = std::move(trackingFailure); + } else { + // Transform op errors have precedence, report those first. + if (result.isSilenceableFailure()) + result.attachNote() << "tracking listener also failed: " + << trackingFailure.getMessage(); + (void)trackingFailure.silence(); + } + } if (result.isDefiniteFailure()) return result; @@ -1161,6 +1181,16 @@ // TrackingListener //===----------------------------------------------------------------------===// +transform::TrackingListener::TrackingListener(TransformState &state, + TransformOpInterface op) + : TransformState::Extension(state), transformOp(op) { + if (op) { + for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) { + consumedHandles.insert(opOperand->get()); + } + } +} + Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) { Operation *defOp = nullptr; for (Value v : values) { @@ -1267,15 +1297,34 @@ // Op is not tracked. return; } + + // Helper function to check if the current transform op consumes any handle + // that is mapped to `op`. + // + // Note: If a handle was consumed, there shouldn't be any alive users, so it + // is not really necessary to check for consumed handles. However, in case + // there are indeed alive handles that were consumed (which is undefined + // behavior) and a replacement op could not be found, we want to fail with a + // nicer error message: "op uses a handle invalidated..." instead of "could + // not find replacement op". This nicer error is produced later. + auto handleWasConsumed = [&] { + return llvm::any_of(opHandles, + [&](Value h) { return consumedHandles.contains(h); }); + }; + + // Helper function to check if the handle is alive. 
auto hasAliveUser = [&]() { - for (Value v : opHandles) + for (Value v : opHandles) { for (Operation *user : v.getUsers()) - if (!happensBefore(user, transformOp)) + if (user != transformOp && !happensBefore(user, transformOp)) return true; + } return false; }; - if (!hasAliveUser()) { - // The op is tracked but the corresponding handles are dead. + + if (!hasAliveUser() || handleWasConsumed()) { + // The op is tracked but the corresponding handles are dead or were + // consumed. Drop the op from the mapping. (void)replacePayloadOp(op, nullptr); return; } @@ -1326,6 +1375,13 @@ ++errorCounter; } +//===----------------------------------------------------------------------===// +// TransformRewriter +//===----------------------------------------------------------------------===// + +transform::TransformRewriter::TransformRewriter(MLIRContext *ctx) + : RewriterBase(ctx) {} + //===----------------------------------------------------------------------===// // Utilities for TransformEachOpTrait. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1,4 +1,4 @@ -//===- TransformDialect.cpp - Transform dialect operations ----------------===// +//===- TransformOps.cpp - Transform dialect operations --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information.
@@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformOps.h" + #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" @@ -92,7 +93,8 @@ } DiagnosedSilenceableFailure -transform::AlternativesOp::apply(transform::TransformResults &results, +transform::AlternativesOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector originals; if (Value scopeHandle = getScope()) @@ -199,7 +201,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::AnnotateOp::apply(transform::TransformResults &results, +transform::AnnotateOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector targets = llvm::to_vector(state.getPayloadOps(getTarget())); @@ -235,10 +238,9 @@ // ApplyPatternsOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -transform::ApplyPatternsOp::applyToOne(Operation *target, - ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, transform::TransformState &state) { // Gather all specified patterns. 
MLIRContext *ctx = target->getContext(); RewritePatternSet patterns(ctx); @@ -346,7 +348,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::CastOp::applyToOne(Operation *target, ApplyToEachResultList &results, +transform::CastOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *target, ApplyToEachResultList &results, transform::TransformState &state) { results.push_back(target); return DiagnosedSilenceableFailure::success(); @@ -408,7 +411,8 @@ } DiagnosedSilenceableFailure -transform::ForeachMatchOp::apply(transform::TransformResults &results, +transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector> matchActionPairs; @@ -706,7 +710,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ForeachOp::apply(transform::TransformResults &results, +transform::ForeachOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector> resultOps(getNumResults(), {}); @@ -795,6 +800,7 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SetVector parents; for (Operation *target : state.getPayloadOps(getTarget())) { @@ -818,7 +824,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetConsumersOfResult::apply(transform::TransformResults &results, +transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t resultNumber = getResultNumber(); auto payloadOps 
= state.getPayloadOps(getTarget()); @@ -843,7 +850,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetDefiningOp::apply(transform::TransformResults &results, +transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector definingOps; for (Value v : state.getPayloadValues(getTarget())) { @@ -864,7 +872,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetProducerOfOperand::apply(transform::TransformResults &results, +transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t operandNumber = getOperandNumber(); SmallVector producers; @@ -892,7 +901,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetResultOp::apply(transform::TransformResults &results, +transform::GetResultOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t resultNumber = getResultNumber(); SmallVector opResults; @@ -943,7 +953,8 @@ } DiagnosedSilenceableFailure -transform::IncludeOp::apply(transform::TransformResults &results, +transform::IncludeOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto callee = SymbolTable::lookupNearestSymbolFrom( getOperation(), getTarget()); @@ -1081,7 +1092,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::MatchParamCmpIOp::apply(transform::TransformResults &results, +transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState 
&state) { auto signedAPIntAsString = [&](APInt value) { std::string str; @@ -1167,7 +1179,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ParamConstantOp::apply(transform::TransformResults &results, +transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.setParams(cast(getParam()), {getValue()}); return DiagnosedSilenceableFailure::success(); @@ -1178,7 +1191,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::MergeHandlesOp::apply(transform::TransformResults &results, +transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector operations; for (Value operand : getHandles()) @@ -1221,7 +1235,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::NamedSequenceOp::apply(transform::TransformResults &results, +transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { // Nothing to do here. 
return DiagnosedSilenceableFailure::success(); @@ -1387,7 +1402,8 @@ } DiagnosedSilenceableFailure -transform::SplitHandleOp::apply(transform::TransformResults &results, +transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle())); auto produceNumOpsError = [&]() { @@ -1448,7 +1464,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ReplicateOp::apply(transform::TransformResults &results, +transform::ReplicateOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern())); for (const auto &en : llvm::enumerate(getHandles())) { @@ -1488,7 +1505,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::SequenceOp::apply(transform::TransformResults &results, +transform::SequenceOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { // Map the entry block argument to the list of operations. 
auto scope = state.make_region_scope(*getBodyBlock()->getParent()); @@ -1778,7 +1796,8 @@ } DiagnosedSilenceableFailure -transform::PrintOp::apply(transform::TransformResults &results, +transform::PrintOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { llvm::outs() << "[[[ IR printer: "; if (getName().has_value()) diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp --- a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp @@ -138,7 +138,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PDLMatchOp::apply(transform::TransformResults &results, +transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); assert(extension && @@ -167,7 +168,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::WithPDLPatternsOp::apply(transform::TransformResults &results, +transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { TransformOpInterface transformOp = nullptr; for (Operation &nested : getBody().front()) { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -45,7 +45,8 @@ return llvm::StringLiteral("transform.test_transform_op"); } - DiagnosedSilenceableFailure apply(transform::TransformResults &results, + DiagnosedSilenceableFailure 
apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { InFlightDiagnostic remark = emitRemark() << "applying transformation"; if (Attribute message = getMessage()) @@ -98,7 +99,8 @@ "transform.test_transform_unrestricted_op_no_interface"); } - DiagnosedSilenceableFailure apply(transform::TransformResults &results, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -110,6 +112,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { results.set(cast(getResult()), @@ -129,6 +132,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToSelfOperand::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.setValues(llvm::cast(getOut()), getIn()); return DiagnosedSilenceableFailure::success(); @@ -143,7 +147,8 @@ DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToResult::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { if (target->getNumResults() <= getNumber()) return emitSilenceableError() << "payload has no result #" << getNumber(); @@ -160,7 +165,8 @@ DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { if (!target->getBlock()) return 
emitSilenceableError() << "payload has no parent block"; @@ -183,7 +189,8 @@ } DiagnosedSilenceableFailure -mlir::test::TestConsumeOperand::apply(transform::TransformResults &results, +mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -197,6 +204,7 @@ } DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto payload = state.getPayloadOps(getOperand()); assert(llvm::hasSingleElement(payload) && "expected a single target op"); @@ -237,6 +245,7 @@ } DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto payload = state.getPayloadOps(getOperand()); for (Operation *op : payload) @@ -252,6 +261,7 @@ } DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { ArrayRef values = state.getPayloadValues(getIn()); for (Value value : values) { @@ -277,15 +287,16 @@ transform::onlyReadsPayload(effects); } -DiagnosedSilenceableFailure -mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { state.addExtension(getMessageAttr()); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState 
&state) { auto *extension = state.getExtension(); if (!extension) { @@ -316,6 +327,7 @@ } DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); if (!extension) @@ -337,14 +349,15 @@ } DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { state.removeExtension(); return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure -mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto payloadOps = state.getPayloadOps(getTarget()); auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); results.set(llvm::cast(getResult()), reversedOps); @@ -352,6 +365,7 @@ } DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -361,6 +375,7 @@ DiagnosedSilenceableFailure mlir::test::TestBranchingTransformOpTerminator::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -369,6 +384,7 @@ SmallVectorImpl &effects) {} DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { emitRemark() << getRemark(); for (Operation *op : state.getPayloadOps(getTarget())) @@ 
-386,7 +402,8 @@ } DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); @@ -395,7 +412,8 @@ DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { static int count = 0; if (count++ == 0) { @@ -407,7 +425,8 @@ DiagnosedSilenceableFailure mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); @@ -417,7 +436,8 @@ DiagnosedSilenceableFailure mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(nullptr); @@ -427,7 +447,8 @@ DiagnosedSilenceableFailure mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { if (target->hasAttr("target_me")) return DiagnosedSilenceableFailure::success(); @@ -436,6 +457,7 @@ 
DiagnosedSilenceableFailure mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { if (!getHandle()) emitRemark() << 0; @@ -449,7 +471,8 @@ } DiagnosedSilenceableFailure -mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results, +mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.set(llvm::cast(getCopy()), state.getPayloadOps(getHandle())); @@ -498,6 +521,7 @@ DiagnosedSilenceableFailure mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { int64_t count = 0; for (Operation *op : state.getPayloadOps(getTarget())) { @@ -520,7 +544,8 @@ } DiagnosedSilenceableFailure -mlir::test::TestPrintParamOp::apply(transform::TransformResults &results, +mlir::test::TestPrintParamOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { std::string str; llvm::raw_string_ostream os(str); @@ -537,7 +562,8 @@ } DiagnosedSilenceableFailure -mlir::test::TestAddToParamOp::apply(transform::TransformResults &results, +mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector values(/*Size=*/1, /*Value=*/0); if (Value param = getParam()) { @@ -559,6 +585,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceParamWithNumberOfTestOps::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { Builder builder(getContext()); SmallVector result = llvm::to_vector( @@ -577,6 +604,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceIntegerParamWithTypeOp::apply( + 
transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { Attribute zero = IntegerAttr::get(getType(), 0); results.setParams(llvm::cast(getResult()), zero); @@ -599,7 +627,8 @@ DiagnosedSilenceableFailure mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne( - Operation *target, ::transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + ::transform::ApplyToEachResultList &results, ::transform::TransformState &state) { Builder builder(getContext()); if (getFirstResultIsParam()) { @@ -625,6 +654,7 @@ } DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector null({nullptr}); results.set(llvm::cast(getOut()), null); @@ -632,6 +662,7 @@ } DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.set(cast(getOut()), {}); return DiagnosedSilenceableFailure::success(); @@ -642,9 +673,9 @@ transform::producesHandle(getOut(), effects); } -DiagnosedSilenceableFailure -mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.setParams(llvm::cast(getOut()), Attribute()); return DiagnosedSilenceableFailure::success(); } @@ -654,9 +685,9 @@ transform::producesHandle(getOut(), effects); } -DiagnosedSilenceableFailure -mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply( + 
transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.setValues(llvm::cast(getOut()), Value()); return DiagnosedSilenceableFailure::success(); } @@ -676,6 +707,7 @@ } DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.set(llvm::cast(getOut()), state.getPayloadOps(getIn())); return DiagnosedSilenceableFailure::success(); @@ -694,10 +726,9 @@ } DiagnosedSilenceableFailure -mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results, +mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { - transform::ErrorCheckingTrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); int64_t numIterations = 0; // `getPayloadOps` returns an iterator that skips ops that are erased in the diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -72,6 +72,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -89,6 +90,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -116,6 +118,7 @@ let cppNamespace = "::mlir::test"; let 
extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state) { @@ -260,6 +263,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -277,6 +281,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -295,6 +300,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -313,6 +319,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -329,6 +336,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -428,6 +436,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList 
&results, ::mlir::transform::TransformState &state); @@ -498,7 +507,8 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ - DiagnosedSilenceableFailure apply(transform::TransformResults &results, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { llvm_unreachable("op should not be used as a transform"); return DiagnosedSilenceableFailure::definiteFailure();