diff --git a/mlir/examples/transform/Ch2/lib/MyExtension.cpp b/mlir/examples/transform/Ch2/lib/MyExtension.cpp --- a/mlir/examples/transform/Ch2/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch2/lib/MyExtension.cpp @@ -85,6 +85,8 @@ // The silenceable failure additionally carries a Diagnostic that can be emitted // to the user. ::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply( + // The rewriter that should be used when modifying IR. + ::mlir::transform::TransformRewriter &rewriter, // The list of payload IR entities that will be associated with the // transform IR values defined by this transform operation. In this case, it // can remain empty as there are no results. diff --git a/mlir/examples/transform/Ch3/include/MyExtension.td b/mlir/examples/transform/Ch3/include/MyExtension.td --- a/mlir/examples/transform/Ch3/include/MyExtension.td +++ b/mlir/examples/transform/Ch3/include/MyExtension.td @@ -60,6 +60,7 @@ // Declare the function implementing the interface for a single payload operation. let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::func::CallOp call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -89,6 +90,7 @@ // Declare the function implementing the interface for a single payload operation. let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::CallOpInterface call, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/examples/transform/Ch3/lib/MyExtension.cpp b/mlir/examples/transform/Ch3/lib/MyExtension.cpp --- a/mlir/examples/transform/Ch3/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch3/lib/MyExtension.cpp @@ -113,6 +113,8 @@ // to the user. 
::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::applyToOne( + // The rewriter that should be used when modifying IR. + ::mlir::transform::TransformRewriter &rewriter, // The single payload operation to which the transformation is applied. ::mlir::func::CallOp call, // The payload IR entities that will be appended to lists associated with @@ -146,27 +148,27 @@ // CallToOp //===---------------------------------------------------------------------===// -static mlir::Operation *replaceCallWithOp(mlir::CallOpInterface call) { +static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter, + mlir::CallOpInterface call) { // Construct an operation from an unregistered dialect. This is discouraged // and is only used here for brevity of the overall example. mlir::OperationState state(call.getLoc(), "my.mm4"); state.types.assign(call->result_type_begin(), call->result_type_end()); state.operands.assign(call->operand_begin(), call->operand_end()); - mlir::OpBuilder builder(call); - mlir::Operation *replacement = builder.create(state); - call->replaceAllUsesWith(replacement->getResults()); - call->erase(); + mlir::Operation *replacement = rewriter.create(state); + rewriter.replaceOp(call, replacement->getResults()); return replacement; } // See above for the signature description. mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne( - mlir::CallOpInterface call, mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call, + mlir::transform::ApplyToEachResultList &results, mlir::transform::TransformState &state) { // Dispatch to the actual transformation. - Operation *replacement = replaceCallWithOp(call); + Operation *replacement = replaceCallWithOp(rewriter, call); // Associate the payload operation produced by the rewrite with the result // handle of this transform operation. 
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -152,6 +152,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::EmptyOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td --- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td @@ -126,6 +126,7 @@ }]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -189,6 +190,7 @@ }]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -155,6 +155,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, 
::mlir::transform::TransformState &state); @@ -270,6 +271,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -313,6 +315,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::GenericOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -350,6 +353,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::PackOp target, ::mlir::transform::ApplyToEachResultList &transformResults, ::mlir::transform::TransformState &state); @@ -389,6 +393,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::UnPackOp target, ::mlir::transform::ApplyToEachResultList &transformResults, ::mlir::transform::TransformState &state); @@ -531,6 +536,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, TransformState &state); @@ -837,6 +843,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -931,6 +938,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::PadOp, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -978,6 +986,7 @@ let 
extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1051,6 +1060,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1098,6 +1108,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1309,6 +1320,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1413,6 +1425,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1522,6 +1535,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1726,6 +1740,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::transform::TransformResults &transformResults, ::mlir::transform::TransformState &state); @@ -1862,6 +1877,7 @@ ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + 
::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1911,6 +1927,7 @@ let extraClassDeclaration = [{ // TODO: applyToOne. ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::transform::TransformResults &transformResults, ::mlir::transform::TransformState &state); @@ -1955,7 +1972,8 @@ ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::func::FuncOp target, + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::func::FuncOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; @@ -2031,6 +2049,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -2079,6 +2098,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -2117,6 +2137,7 @@ ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -158,6 +158,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + 
::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -126,6 +126,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::scf::ForOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -168,6 +169,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::scf::ForOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -201,6 +203,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -228,6 +231,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -269,6 +273,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::scf::IfOp ifOp, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td --- 
a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td +++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td @@ -124,6 +124,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h @@ -50,7 +50,8 @@ return success(); } - DiagnosedSilenceableFailure apply(TransformResults &results, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { Value operandHandle = cast(this->getOperation()).getOperandHandle(); auto payload = state.getPayloadOps(operandHandle); @@ -90,7 +91,8 @@ return success(); } - DiagnosedSilenceableFailure apply(TransformResults &results, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { Value operandHandle = cast(this->getOperation()).getOperandHandle(); ValueRange payload = state.getPayloadValues(operandHandle); diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -18,11 +18,11 @@ #include "mlir/Support/LogicalResult.h" namespace mlir { - namespace transform { class TransformOpInterface; class TransformResults; +class TransformRewriter; class TransformState; using Param = Attribute; @@ -854,8 +854,7 @@ public TransformState::Extension { public: /// Create a new TrackingListener for usage in the 
specified transform op. - explicit TrackingListener(TransformState &state, TransformOpInterface op) - : TransformState::Extension(state), transformOp(op) {} + TrackingListener(TransformState &state, TransformOpInterface op); protected: /// Return a replacement payload op for the given op, which is going to be @@ -935,6 +934,9 @@ /// The transform op in which this TrackingListener is used. TransformOpInterface transformOp; + + /// The handles that are consumed by the transform op. + DenseSet consumedHandles; }; /// A specialized listener that keeps track of cases in which no replacement @@ -966,6 +968,15 @@ int64_t errorCounter = 0; }; +class TransformRewriter : public RewriterBase, + public TransformState::Extension { +protected: + friend class TransformState; + + /// Create a new TransformRewriter. + explicit TransformRewriter(MLIRContext *ctx, TransformState &state); +}; + /// This trait is supposed to be attached to Transform dialect operations that /// can be standalone top-level transforms. Such operations typically contain /// other Transform dialect operations that can be executed following some @@ -1062,7 +1073,8 @@ /// 5. If any `applyToOne` return silenceableFailure, the transformation is /// considered silenceable. /// 6. Otherwise the transformation is considered successful. - DiagnosedSilenceableFailure apply(TransformResults &transformResults, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state); /// Checks that the op matches the expectations of this trait. @@ -1305,14 +1317,14 @@ /// `targets` contains operations of the same class and a silenceable failure /// is reported if it does not. 
template -DiagnosedSilenceableFailure -applyTransformToEach(TransformOpTy transformOp, Range &&targets, - SmallVectorImpl &results, - TransformState &state) { +DiagnosedSilenceableFailure applyTransformToEach( + TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets, + SmallVectorImpl &results, TransformState &state) { using OpTy = typename llvm::function_traits< - decltype(&TransformOpTy::applyToOne)>::template arg_t<0>; + decltype(&TransformOpTy::applyToOne)>::template arg_t<1>; static_assert(std::is_convertible::value, "expected transform function to take an operation"); + OpBuilder::InsertionGuard g(rewriter); SmallVector silenceableStack; unsigned expectedNumResults = transformOp->getNumResults(); @@ -1329,8 +1341,9 @@ ApplyToEachResultList partialResults; partialResults.reserve(expectedNumResults); Location specificOpLoc = specificOp->getLoc(); + rewriter.setInsertionPoint(specificOp); DiagnosedSilenceableFailure res = - transformOp.applyToOne(specificOp, partialResults, state); + transformOp.applyToOne(rewriter, specificOp, partialResults, state); if (res.isDefiniteFailure()) return DiagnosedSilenceableFailure::definiteFailure(); @@ -1359,7 +1372,8 @@ template mlir::DiagnosedSilenceableFailure mlir::transform::TransformEachOpTrait::apply( - TransformResults &transformResults, TransformState &state) { + TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { auto targets = state.getPayloadOps(this->getOperation()->getOperand(0)); // Step 1. Handle the corner case where no target is specified. @@ -1384,7 +1398,7 @@ // corresponding results entry. SmallVector results; DiagnosedSilenceableFailure result = detail::applyTransformToEach( - cast(this->getOperation()), targets, results, state); + cast(this->getOperation()), rewriter, targets, results, state); // Step 3. Propagate the definite failure if any and bail out. 
if (result.isDefiniteFailure()) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -45,6 +45,7 @@ /*returnType=*/"::mlir::DiagnosedSilenceableFailure", /*name=*/"apply", /*arguments=*/(ins + "::mlir::transform::TransformRewriter &":$rewriter, "::mlir::transform::TransformResults &":$transformResults, "::mlir::transform::TransformState &":$state )>, diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -177,6 +177,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -204,6 +205,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -63,7 +63,8 @@ } // namespace DiagnosedSilenceableFailure -SimplifyBoundedAffineOpsOp::apply(TransformResults &results, +SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { // Get constraints for bounded values. 
SmallVector lbs; @@ -127,6 +128,8 @@ SimplifyAffineMinMaxOp>(getContext(), cstr); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); GreedyRewriteConfig config; + config.listener = + static_cast(rewriter.getListener()); config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; // Apply the simplification pattern to a fixpoint. if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) { diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -26,7 +26,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::OneShotBufferizeOp::apply(TransformResults &transformResults, +transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { OneShotBufferizationOptions options; options.allowReturnAllocs = getAllowReturnAllocs(); @@ -71,10 +72,9 @@ modifiesPayload(effects); } -DiagnosedSilenceableFailure -transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults, - TransformState &state) { - IRRewriter rewriter(getContext()); +DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( + transform::TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { OneShotBufferizationOptions options; options.allowReturnAllocs = true; @@ -95,11 +95,9 @@ // EmptyTensorToAllocTensorOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target, - ApplyToEachResultList &results, - transform::TransformState &state) { - IRRewriter rewriter(target->getContext()); 
+DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne( + transform::TransformRewriter &rewriter, tensor::EmptyOp target, + ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); auto alloc = rewriter.replaceOpWithNewOp( target, target.getType(), target.getDynamicSizes()); diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -398,7 +398,7 @@ /// Alter kernel configuration of the given kernel. static DiagnosedSilenceableFailure -alterGpuLaunch(IRRewriter &rewriter, LaunchOp gpuLaunch, +alterGpuLaunch(RewriterBase &rewriter, LaunchOp gpuLaunch, TransformOpInterface transformOp, std::optional gridDimX = std::nullopt, std::optional gridDimY = std::nullopt, @@ -661,12 +661,10 @@ return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure -transform::MapForallToBlocks::applyToOne(Operation *target, - ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, transform::TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); - IRRewriter rewriter(getContext()); auto transformOp = cast(getOperation()); if (!getGenerateGpuLaunch() && !gpuLaunch) { @@ -856,7 +854,8 @@ } DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne( - Operation *target, ApplyToEachResultList &results, TransformState &state) { + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); auto transformOp = cast(getOperation()); @@ -877,7 +876,6 @@ // Set the GPU launch configuration for the block dims early, this is not // subject to 
IR inspection. - IRRewriter rewriter(getContext()); diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt, std::nullopt, std::nullopt, blockDims[0], blockDims[1], blockDims[2]); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -169,13 +169,11 @@ // BufferizeToAllocationOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -transform::BufferizeToAllocationOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { Attribute memorySpace = getMemorySpace().has_value() ? getMemorySpace().value() : Attribute(); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); auto transformed = llvm::to_vector( llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) { return linalg::bufferizeToAllocation(rewriter, v, memorySpace); @@ -196,7 +194,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::DecomposeOp::applyToOne(LinalgOp target, +transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { #define DOWNSCALE(trans) \ @@ -286,7 +285,8 @@ } DiagnosedSilenceableFailure -transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, +transform::FuseOp::apply(transform::TransformRewriter &rewriter, + mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { SmallVector tileSizes = 
extractFromI64ArrayAttr(getTileSizes()); SmallVector tileInterchange = @@ -297,8 +297,6 @@ tilingOptions = tilingOptions.setTileSizes(tileSizes); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); LogicalResult result = applyTilingToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), transformResults, @@ -721,7 +719,8 @@ } DiagnosedSilenceableFailure -transform::FuseIntoContainingOp::apply(transform::TransformResults &results, +transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector fusedOps; auto producerOps = state.getPayloadOps(getProducerOp()); @@ -764,8 +763,6 @@ return failure(); }; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); while (!remainingProducers.empty()) { auto nextProducer = getNextProducer(); if (failed(nextProducer)) { @@ -842,7 +839,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GeneralizeOp::applyToOne(LinalgOp target, +transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Exit early if no transformation is needed. 
@@ -850,8 +848,6 @@ results.push_back(target); return DiagnosedSilenceableFailure::success(); } - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr generic = generalizeNamedOp(rewriter, target); if (succeeded(generic)) { @@ -866,7 +862,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::InterchangeOp::applyToOne(GenericOp target, +transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter, + GenericOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { ArrayRef interchangeVector = getIteratorInterchange(); @@ -875,8 +872,6 @@ results.push_back(target); return DiagnosedSilenceableFailure::success(); } - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr res = interchangeGenericOp(rewriter, target, SmallVector(interchangeVector.begin(), @@ -904,10 +899,9 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne( - tensor::PackOp target, transform::ApplyToEachResultList &transformResults, + transform::TransformRewriter &rewriter, tensor::PackOp target, + transform::ApplyToEachResultList &transformResults, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr res = lowerPack(rewriter, target); if (failed(res)) { @@ -925,10 +919,9 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne( - tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults, + transform::TransformRewriter &rewriter, tensor::UnPackOp target, + transform::ApplyToEachResultList &transformResults, transform::TransformState &state) { - 
TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr res = lowerUnPack(rewriter, target); if (failed(res)) { @@ -964,7 +957,8 @@ } DiagnosedSilenceableFailure -transform::MatchOp::apply(transform::TransformResults &results, +transform::MatchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { llvm::StringSet<> strs; if (getOps().has_value()) @@ -1053,8 +1047,8 @@ } DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, - TransformState &state) { + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, TransformState &state) { if (isa(getLowSize().getType())) { if (target.hasDynamicShape()) { auto diag = emitSilenceableError() @@ -1155,7 +1149,8 @@ } DiagnosedSilenceableFailure -transform::PackOp::apply(transform::TransformResults &transformResults, +transform::PackOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); // If nothing to pack, propagate success. 
@@ -1184,8 +1179,6 @@ DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations( state, *this, packedSizes, getMixedPackedSizes()); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(linalgOp); FailureOr maybeResult = pack(rewriter, linalgOp, packedSizes); if (failed(maybeResult)) @@ -1364,11 +1357,10 @@ } DiagnosedSilenceableFailure -PackGreedilyOp::apply(transform::TransformResults &transformResults, +PackGreedilyOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); for (Operation *op : state.getPayloadOps(getTarget())) { auto linalgOp = dyn_cast(op); if (!linalgOp) @@ -1464,7 +1456,8 @@ } DiagnosedSilenceableFailure -transform::PackTransposeOp::apply(transform::TransformResults &transformResults, +transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp()); auto linalgOps = state.getPayloadOps(getTargetLinalgOp()); @@ -1542,8 +1535,6 @@ assert(packOp && linalgOp && "unexpected null op"); // Step 3. Actually transpose the ops. - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr res = packTranspose( rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm()); // Preconditions have been checked, it is an error to fail here. 
@@ -1568,7 +1559,8 @@ //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PadOp::applyToOne(LinalgOp target, +transform::PadOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Convert the integer packing flags to booleans. @@ -1616,8 +1608,6 @@ transposePaddings.push_back( extractFromI64ArrayAttr(cast(transposeVector))); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); LinalgOp paddedOp; SmallVector paddingDimensions = extractFromI64ArrayAttr(getPaddingDimensions()); @@ -1684,6 +1674,7 @@ //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); @@ -1700,8 +1691,6 @@ if (!padOp || !loopOp) return emitDefiniteFailure() << "requires exactly 2 non-null handles"; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr result = linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp, getTranspose()); @@ -1740,13 +1729,12 @@ } DiagnosedSilenceableFailure -transform::HoistPadOp::applyToOne(tensor::PadOp target, +transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter, + tensor::PadOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { tensor::PadOp hoistedPadOp; SmallVector transposeOps; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr result = hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(), hoistedPadOp, transposeOps); @@ -1779,7 +1767,8 @@ 
//===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PromoteOp::applyToOne(LinalgOp target, +transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { LinalgPromotionOptions promotionOptions; @@ -1829,8 +1818,6 @@ if (failed(promoteSubviewsPrecondition(target, promotionOptions))) return emitDefaultDefiniteFailure(target); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) @@ -1844,7 +1831,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ReplaceOp::apply(TransformResults &transformResults, +transform::ReplaceOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { auto payload = state.getPayloadOps(getTarget()); @@ -1859,8 +1847,6 @@ } // Clone and replace. 
- TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); Operation *pattern = &getBodyRegion().front().front(); SmallVector replacements; for (Operation *target : payload) { @@ -1904,7 +1890,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ScalarizeOp::applyToOne(LinalgOp target, +transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::SCFTilingOptions tilingOptions; @@ -1916,8 +1903,6 @@ AffineMap map = target.getShapesToLoopsMap(); if (!map) return tileSizes; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); SmallVector shapeSizes = affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map, allShapeSizes); @@ -1931,8 +1916,6 @@ return tileSizes; }); SmallVector emptyTileSizes; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(target.getOperation()), tilingOptions); @@ -1956,11 +1939,10 @@ DiagnosedSilenceableFailure transform::RewriteInDestinationPassingStyleOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { SmallVector res; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr maybeResult = TypeSwitch>(target) @@ -1978,13 +1960,12 @@ // SplitOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, - TransformState &state) { +DiagnosedSilenceableFailure +SplitOp::apply(transform::TransformRewriter 
&rewriter, + TransformResults &results, TransformState &state) { // Collect the dynamic split points if provided. SmallVector payload = llvm::to_vector(state.getPayloadOps(getTarget())); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); SmallVector splitPoints; splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { @@ -2184,15 +2165,14 @@ } DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return linalg::SplitReductionOptions{int64_t(getSplitFactor()), unsigned(getInsertSplitDimension()), bool(getInnerParallel())}; }; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr splitResult = (getUseScalingAlgorithm()) @@ -2230,10 +2210,9 @@ } DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr result = scf::tileReductionUsingScf( rewriter, cast(target.getOperation()), @@ -2274,10 +2253,9 @@ } DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); SmallVector numThreads = 
getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); @@ -2363,7 +2341,8 @@ } DiagnosedSilenceableFailure -transform::TileOp::apply(TransformResults &transformResults, +transform::TileOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { ArrayRef tileSizes = getStaticSizes(); @@ -2478,8 +2457,6 @@ } tilingOptions.setInterchange(getInterchange()); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr maybeTilingResult = tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions); if (failed(maybeTilingResult)) @@ -2720,45 +2697,44 @@ } DiagnosedSilenceableFailure -transform::TileToForallOp::apply(transform::TransformResults &transformResults, +transform::TileToForallOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); - auto transformOp = cast(getOperation()); - - // Result payload ops. - SmallVector tileOps; - SmallVector tiledOps; - - // Unpack handles. - SmallVector mixedNumThreads; - DiagnosedSilenceableFailure status = - getPackedNumThreads() - ? unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedNumThreads, getPackedNumThreads()) - : unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedNumThreads, getMixedNumThreads()); - if (!status.succeeded()) - return status; - SmallVector mixedTileSizes; - status = getPackedTileSizes() - ? 
unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedTileSizes, getPackedTileSizes()) - : unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedTileSizes, getMixedTileSizes()); - if (!status.succeeded()) - return status; - - for (Operation *target : state.getPayloadOps(getTarget())) { - linalg::ForallTilingResult tilingResult; - DiagnosedSilenceableFailure diag = tileToForallOpImpl( - rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, - getMapping(), tilingResult); - if (!diag.succeeded()) + auto transformOp = cast(getOperation()); + + // Result payload ops. + SmallVector tileOps; + SmallVector tiledOps; + + // Unpack handles. + SmallVector mixedNumThreads; + DiagnosedSilenceableFailure status = + getPackedNumThreads() + ? unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedNumThreads, getPackedNumThreads()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedNumThreads, getMixedNumThreads()); + if (!status.succeeded()) + return status; + SmallVector mixedTileSizes; + status = getPackedTileSizes() + ? 
unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getPackedTileSizes()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getMixedTileSizes()); + if (!status.succeeded()) + return status; + + for (Operation *target : state.getPayloadOps(getTarget())) { + linalg::ForallTilingResult tilingResult; + DiagnosedSilenceableFailure diag = tileToForallOpImpl( + rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, + getMapping(), tilingResult); + if (!diag.succeeded()) return diag; tileOps.push_back(tilingResult.tileOp); tiledOps.push_back(tilingResult.tiledOp); - } + } transformResults.set(cast(getForallOp()), tileOps); transformResults.set(cast(getTiledOp()), tiledOps); @@ -2833,7 +2809,8 @@ } DiagnosedSilenceableFailure -transform::TileToScfForOp::apply(TransformResults &transformResults, +transform::TileToScfForOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { ArrayRef tileSizes = getStaticSizes(); @@ -2902,8 +2879,6 @@ } tilingOptions.setInterchange(getInterchange()); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr tilingResult = tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions); if (failed(tilingResult)) @@ -3047,7 +3022,8 @@ } // namespace DiagnosedSilenceableFailure -transform::VectorizeOp::applyToOne(Operation *target, +transform::VectorizeOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { if (!target->hasTrait()) { @@ -3093,10 +3069,9 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply( + transform::TransformRewriter &rewriter, mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { - TrackingListener listener(state, *this); 
- IRRewriter rewriter(getContext(), &listener); auto targets = state.getPayloadOps(getTarget()); if (std::empty(targets)) return DiagnosedSilenceableFailure::success(); @@ -3173,7 +3148,8 @@ DiagnosedSilenceableFailure transform::HoistRedundantVectorTransfersOp::applyToOne( - func::FuncOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, func::FuncOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // WARNING: This hoisting does not model parallelism and is generally // incorrect when used on distributed loops with memref semantics! @@ -3188,10 +3164,9 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne( - linalg::LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, linalg::LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); auto maybeTransformed = TypeSwitch>>( @@ -3223,10 +3198,9 @@ DiagnosedSilenceableFailure transform::HoistRedundantTensorSubsetsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); auto forOp = dyn_cast(target); if (forOp) { linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp); @@ -3289,11 +3263,10 @@ } DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne( - Operation *targetOp, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *targetOp, + transform::ApplyToEachResultList &results, transform::TransformState &state) { 
- TrackingListener listener(state, *this); - IRRewriter rewriter(targetOp->getContext(), &listener); rewriter.setInsertionPoint(targetOp); if (auto target = dyn_cast(targetOp)) return doit(rewriter, target, results, state); diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -60,10 +60,10 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; - IRRewriter rewriter(getContext()); for (Operation *op : state.getPayloadOps(getTarget())) { bool canApplyMultiBuffer = true; auto target = cast(op); @@ -105,7 +105,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // Gather IVs. SmallVector ivs; @@ -123,7 +124,6 @@ } // Rewrite IR. 
- IRRewriter rewriter(target->getContext()); FailureOr replacement = failure(); if (auto allocaOp = dyn_cast(target)) { replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -35,7 +35,8 @@ // GetParentForOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetParentForOp::apply(transform::TransformResults &results, +transform::GetParentForOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SetVector parents; for (Operation *target : state.getPayloadOps(getTarget())) { @@ -92,7 +93,8 @@ } DiagnosedSilenceableFailure -transform::LoopOutlineOp::apply(transform::TransformResults &results, +transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector functions; SmallVector calls; @@ -100,7 +102,6 @@ for (Operation *target : state.getPayloadOps(getTarget())) { Location location = target->getLoc(); Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); - IRRewriter rewriter(getContext()); scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); if (!exec) { DiagnosedSilenceableFailure diag = emitSilenceableError() @@ -135,11 +136,11 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::LoopPeelOp::applyToOne(scf::ForOp target, +transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter, + scf::ForOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::ForOp result; - IRRewriter rewriter(target->getContext()); // This helper 
returns failure when peeling does not occur (i.e. when the IR // is not modified). This is not a failure for the op as the postcondition: // "the loop trip count is divisible by the step" @@ -192,7 +193,8 @@ } DiagnosedSilenceableFailure -transform::LoopPipelineOp::applyToOne(scf::ForOp target, +transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter, + scf::ForOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::PipeliningOption options; @@ -203,7 +205,6 @@ getReadLatency()); }; scf::ForLoopPipeliningPattern pattern(options, target->getContext()); - IRRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr patternResult = scf::pipelineForLoop(rewriter, target, options); @@ -219,7 +220,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::LoopUnrollOp::applyToOne(Operation *op, +transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *op, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); @@ -241,7 +243,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::LoopCoalesceOp::applyToOne(Operation *op, +transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *op, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); @@ -276,12 +279,10 @@ } DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne( - scf::IfOp ifOp, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, scf::IfOp ifOp, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(ifOp->getContext(), &listener); rewriter.setInsertionPoint(ifOp); - Region &region = 
getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion(); if (!llvm::hasSingleElement(region)) { diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -118,7 +118,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // Gather IVs. SmallVector ivs; @@ -136,7 +137,6 @@ } // Rewrite IR. - IRRewriter rewriter(target->getContext()); FailureOr replacement = failure(); if (auto padOp = dyn_cast(target)) { replacement = tensor::buildIndependentOp(rewriter, padOp, ivs); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -896,12 +896,31 @@ return diag; } + // Prepare rewriter and listener. + transform::ErrorCheckingTrackingListener trackingListener(*this, transform); + transform::TransformRewriter rewriter(transform->getContext(), *this); + rewriter.setListener(&trackingListener); + // Compute the result but do not short-circuit the silenceable failure case as // we still want the handles to propagate properly so the "suppress" mode can // proceed on a best effort basis. 
transform::TransformResults results(transform->getNumResults()); - DiagnosedSilenceableFailure result(transform.apply(results, *this)); + DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this)); compactOpHandles(); + + // Error handling: fail if transform or listener failed. + DiagnosedSilenceableFailure trackingFailure = + trackingListener.checkAndResetError(); + if (!trackingFailure.succeeded()) { + if (result.succeeded()) { + result = std::move(trackingFailure); + } else { + // Transform op errors have precedence, report those first. + result.attachNote() << "tracking listener also failed: " + << trackingFailure.getMessage(); + (void)trackingFailure.silence(); + } + } if (result.isDefiniteFailure()) return result; @@ -1161,6 +1180,14 @@ // TrackingListener //===----------------------------------------------------------------------===// +transform::TrackingListener::TrackingListener(TransformState &state, + TransformOpInterface op) + : TransformState::Extension(state), transformOp(op) { + if (op) + for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) + consumedHandles.insert(opOperand->get()); +} + Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) { Operation *defOp = nullptr; for (Value v : values) { @@ -1267,15 +1294,34 @@ // Op is not tracked. return; } + + // Helper function to check if the current transform op consumes any handle + // that is mapped to `op`. + // + // Note: If a handle was consumed, there shouldn't be any alive users, so it + // is not really necessary to check for consumed handles. However, in case + // there are indeed alive handles that were consumed (which is invalid IR) and + // a replacement op could not be found, we want to fail with a nicer error + // message: "op uses a handle invalidated..." instead of "could not find + // replacement op". This nicer error is produced later. 
+ auto handleWasConsumed = [&] { + return llvm::any_of(opHandles, + [&](Value h) { return consumedHandles.contains(h); }); + }; + + // Helper function to check if the handle is alive. auto hasAliveUser = [&]() { - for (Value v : opHandles) + for (Value v : opHandles) { for (Operation *user : v.getUsers()) - if (!happensBefore(user, transformOp)) + if (user != transformOp && !happensBefore(user, transformOp)) return true; + } return false; }; - if (!hasAliveUser()) { - // The op is tracked but the corresponding handles are dead. + + if (!hasAliveUser() || handleWasConsumed()) { + // The op is tracked but the corresponding handles are dead or were + // consumed. Drop the op from the mapping. (void)replacePayloadOp(op, nullptr); return; } @@ -1326,6 +1372,14 @@ ++errorCounter; } +//===----------------------------------------------------------------------===// +// TransformRewriter +//===----------------------------------------------------------------------===// + +transform::TransformRewriter::TransformRewriter(MLIRContext *ctx, + TransformState &state) + : RewriterBase(ctx), TransformState::Extension(state) {} + //===----------------------------------------------------------------------===// // Utilities for TransformEachOpTrait. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1,4 +1,4 @@ -//===- TransformDialect.cpp - Transform dialect operations ----------------===// +//===- TransformOps.cpp - Transform dialect operations --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information.
@@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformOps.h" + #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" @@ -92,7 +93,8 @@ } DiagnosedSilenceableFailure -transform::AlternativesOp::apply(transform::TransformResults &results, +transform::AlternativesOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector originals; if (Value scopeHandle = getScope()) @@ -199,7 +201,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::AnnotateOp::apply(transform::TransformResults &results, +transform::AnnotateOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector targets = llvm::to_vector(state.getPayloadOps(getTarget())); @@ -235,10 +238,9 @@ // ApplyPatternsOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -transform::ApplyPatternsOp::applyToOne(Operation *target, - ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, transform::TransformState &state) { // Gather all specified patterns. 
MLIRContext *ctx = target->getContext(); RewritePatternSet patterns(ctx); @@ -346,7 +348,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::CastOp::applyToOne(Operation *target, ApplyToEachResultList &results, +transform::CastOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *target, ApplyToEachResultList &results, transform::TransformState &state) { results.push_back(target); return DiagnosedSilenceableFailure::success(); @@ -408,7 +411,8 @@ } DiagnosedSilenceableFailure -transform::ForeachMatchOp::apply(transform::TransformResults &results, +transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector> matchActionPairs; @@ -706,7 +710,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ForeachOp::apply(transform::TransformResults &results, +transform::ForeachOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector> resultOps(getNumResults(), {}); @@ -795,6 +800,7 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SetVector parents; for (Operation *target : state.getPayloadOps(getTarget())) { @@ -818,7 +824,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetConsumersOfResult::apply(transform::TransformResults &results, +transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t resultNumber = getResultNumber(); auto payloadOps 
= state.getPayloadOps(getTarget()); @@ -843,7 +850,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetDefiningOp::apply(transform::TransformResults &results, +transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector definingOps; for (Value v : state.getPayloadValues(getTarget())) { @@ -864,7 +872,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetProducerOfOperand::apply(transform::TransformResults &results, +transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t operandNumber = getOperandNumber(); SmallVector producers; @@ -892,7 +901,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetResultOp::apply(transform::TransformResults &results, +transform::GetResultOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t resultNumber = getResultNumber(); SmallVector opResults; @@ -943,7 +953,8 @@ } DiagnosedSilenceableFailure -transform::IncludeOp::apply(transform::TransformResults &results, +transform::IncludeOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto callee = SymbolTable::lookupNearestSymbolFrom( getOperation(), getTarget()); @@ -1081,7 +1092,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::MatchParamCmpIOp::apply(transform::TransformResults &results, +transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState 
&state) { auto signedAPIntAsString = [&](APInt value) { std::string str; @@ -1167,7 +1179,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ParamConstantOp::apply(transform::TransformResults &results, +transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.setParams(cast(getParam()), {getValue()}); return DiagnosedSilenceableFailure::success(); @@ -1178,7 +1191,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::MergeHandlesOp::apply(transform::TransformResults &results, +transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector operations; for (Value operand : getHandles()) @@ -1221,7 +1235,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::NamedSequenceOp::apply(transform::TransformResults &results, +transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { // Nothing to do here. 
return DiagnosedSilenceableFailure::success(); @@ -1387,7 +1402,8 @@ } DiagnosedSilenceableFailure -transform::SplitHandleOp::apply(transform::TransformResults &results, +transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle())); auto produceNumOpsError = [&]() { @@ -1448,7 +1464,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ReplicateOp::apply(transform::TransformResults &results, +transform::ReplicateOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern())); for (const auto &en : llvm::enumerate(getHandles())) { @@ -1488,7 +1505,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::SequenceOp::apply(transform::TransformResults &results, +transform::SequenceOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { // Map the entry block argument to the list of operations. 
auto scope = state.make_region_scope(*getBodyBlock()->getParent()); @@ -1778,7 +1796,8 @@ } DiagnosedSilenceableFailure -transform::PrintOp::apply(transform::TransformResults &results, +transform::PrintOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { llvm::outs() << "[[[ IR printer: "; if (getName().has_value()) diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp --- a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp @@ -138,7 +138,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PDLMatchOp::apply(transform::TransformResults &results, +transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); assert(extension && @@ -167,7 +168,8 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::WithPDLPatternsOp::apply(transform::TransformResults &results, +transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { TransformOpInterface transformOp = nullptr; for (Operation &nested : getBody().front()) { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -45,7 +45,8 @@ return llvm::StringLiteral("transform.test_transform_op"); } - DiagnosedSilenceableFailure apply(transform::TransformResults &results, + DiagnosedSilenceableFailure 
apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { InFlightDiagnostic remark = emitRemark() << "applying transformation"; if (Attribute message = getMessage()) @@ -98,7 +99,8 @@ "transform.test_transform_unrestricted_op_no_interface"); } - DiagnosedSilenceableFailure apply(transform::TransformResults &results, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -110,6 +112,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { results.set(cast(getResult()), @@ -129,6 +132,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToSelfOperand::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.setValues(llvm::cast(getOut()), getIn()); return DiagnosedSilenceableFailure::success(); @@ -143,7 +147,8 @@ DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToResult::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { if (target->getNumResults() <= getNumber()) return emitSilenceableError() << "payload has no result #" << getNumber(); @@ -160,7 +165,8 @@ DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { if (!target->getBlock()) return 
emitSilenceableError() << "payload has no parent block"; @@ -183,7 +189,8 @@ } DiagnosedSilenceableFailure -mlir::test::TestConsumeOperand::apply(transform::TransformResults &results, +mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -197,6 +204,7 @@ } DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto payload = state.getPayloadOps(getOperand()); assert(llvm::hasSingleElement(payload) && "expected a single target op"); @@ -237,6 +245,7 @@ } DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto payload = state.getPayloadOps(getOperand()); for (Operation *op : payload) @@ -252,6 +261,7 @@ } DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { ArrayRef values = state.getPayloadValues(getIn()); for (Value value : values) { @@ -277,15 +287,16 @@ transform::onlyReadsPayload(effects); } -DiagnosedSilenceableFailure -mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { state.addExtension(getMessageAttr()); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState 
&state) { auto *extension = state.getExtension(); if (!extension) { @@ -316,6 +327,7 @@ } DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); if (!extension) @@ -337,14 +349,15 @@ } DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { state.removeExtension(); return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure -mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto payloadOps = state.getPayloadOps(getTarget()); auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); results.set(llvm::cast(getResult()), reversedOps); @@ -352,6 +365,7 @@ } DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -361,6 +375,7 @@ DiagnosedSilenceableFailure mlir::test::TestBranchingTransformOpTerminator::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -369,6 +384,7 @@ SmallVectorImpl &effects) {} DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { emitRemark() << getRemark(); for (Operation *op : state.getPayloadOps(getTarget())) @@ 
-386,7 +402,8 @@ } DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); @@ -395,7 +412,8 @@ DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { static int count = 0; if (count++ == 0) { @@ -407,7 +425,8 @@ DiagnosedSilenceableFailure mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); @@ -417,7 +436,8 @@ DiagnosedSilenceableFailure mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(nullptr); @@ -427,7 +447,8 @@ DiagnosedSilenceableFailure mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { if (target->hasAttr("target_me")) return DiagnosedSilenceableFailure::success(); @@ -436,6 +457,7 @@ 
DiagnosedSilenceableFailure mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { if (!getHandle()) emitRemark() << 0; @@ -449,7 +471,8 @@ } DiagnosedSilenceableFailure -mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results, +mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.set(llvm::cast(getCopy()), state.getPayloadOps(getHandle())); @@ -498,6 +521,7 @@ DiagnosedSilenceableFailure mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { int64_t count = 0; for (Operation *op : state.getPayloadOps(getTarget())) { @@ -520,7 +544,8 @@ } DiagnosedSilenceableFailure -mlir::test::TestPrintParamOp::apply(transform::TransformResults &results, +mlir::test::TestPrintParamOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { std::string str; llvm::raw_string_ostream os(str); @@ -537,7 +562,8 @@ } DiagnosedSilenceableFailure -mlir::test::TestAddToParamOp::apply(transform::TransformResults &results, +mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector values(/*Size=*/1, /*Value=*/0); if (Value param = getParam()) { @@ -559,6 +585,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceParamWithNumberOfTestOps::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { Builder builder(getContext()); SmallVector result = llvm::to_vector( @@ -577,6 +604,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceIntegerParamWithTypeOp::apply( + 
transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { Attribute zero = IntegerAttr::get(getType(), 0); results.setParams(llvm::cast(getResult()), zero); @@ -599,7 +627,8 @@ DiagnosedSilenceableFailure mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne( - Operation *target, ::transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + ::transform::ApplyToEachResultList &results, ::transform::TransformState &state) { Builder builder(getContext()); if (getFirstResultIsParam()) { @@ -625,6 +654,7 @@ } DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector null({nullptr}); results.set(llvm::cast(getOut()), null); @@ -632,6 +662,7 @@ } DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.set(cast(getOut()), {}); return DiagnosedSilenceableFailure::success(); @@ -642,9 +673,9 @@ transform::producesHandle(getOut(), effects); } -DiagnosedSilenceableFailure -mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.setParams(llvm::cast(getOut()), Attribute()); return DiagnosedSilenceableFailure::success(); } @@ -654,9 +685,9 @@ transform::producesHandle(getOut(), effects); } -DiagnosedSilenceableFailure -mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply( + 
transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.setValues(llvm::cast(getOut()), Value()); return DiagnosedSilenceableFailure::success(); } @@ -676,6 +707,7 @@ } DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.set(llvm::cast(getOut()), state.getPayloadOps(getIn())); return DiagnosedSilenceableFailure::success(); @@ -711,17 +743,20 @@ } // namespace DiagnosedSilenceableFailure -mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results, +mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { TestTrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); + // Use a custom rewriter, so that a specialized listener (for this test) can + // be attached. + IRRewriter rewriterUnderTest(getContext(), &listener); int64_t numIterations = 0; // `getPayloadOps` returns an iterator that skips ops that are erased in the // loop body. Replacement ops are not enumerated. for (Operation *op : state.getPayloadOps(getIn())) { ++numIterations; - rewriter.setInsertionPointToEnd(op->getBlock()); + rewriterUnderTest.setInsertionPointToEnd(op->getBlock()); // Erase all payload ops. The outer loop should have only one iteration. 
for (Operation *op : state.getPayloadOps(getIn())) { @@ -736,8 +771,8 @@ OperationState opState(op->getLoc(), replacementName, /*operands=*/ValueRange(), /*types=*/op->getResultTypes(), attributes); - Operation *newOp = rewriter.create(opState); - rewriter.replaceOp(op, newOp->getResults()); + Operation *newOp = rewriterUnderTest.create(opState); + rewriterUnderTest.replaceOp(op, newOp->getResults()); } } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -72,6 +72,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -89,6 +90,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -244,6 +246,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -261,6 +264,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -279,6 +283,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ 
::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -297,6 +302,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -313,6 +319,7 @@ let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation * target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -412,6 +419,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state);