diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1957,8 +1957,9 @@ def HoistRedundantTensorSubsetsOp : Op { + [DeclareOpInterfaceMethods, + TransformEachOpTrait, + TransformOpInterface]> { let description = [{ Hoists supported tensor subset extract/insert operation pairs out of immediately enclosing loop iteratively, if the following conditions @@ -1978,18 +1979,18 @@ #### Return modes: - The operation always succeeds and returns a handle to the transformed - function op. + The operation always succeeds and returns nothing. }]; let arguments = (ins TransformHandleTypeInterface:$target); - let results = (outs TransformHandleTypeInterface:$transformed); + let results = (outs); - let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) "; + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type(operands, results) + }]; - let builders = [ - OpBuilder<(ins "Value":$target)>, - ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation *target, diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3194,12 +3194,11 @@ transform::HoistRedundantTensorSubsetsOp::applyToOne( Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - IRRewriter rewriter(target->getContext()); + TrackingListener listener(state, *this); + IRRewriter rewriter(target->getContext(), &listener); auto forOp = dyn_cast(target); if (forOp) { - scf::ForOp newForOp = - linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp); - results.push_back(newForOp); + linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp); return DiagnosedSilenceableFailure::success(); } @@ -3208,10 +3207,15 @@ target->walk([&](scf::ForOp forOp) { hoistRedundantSubsetExtractInsert(rewriter, forOp); }); - results.push_back(target); return DiagnosedSilenceableFailure::success(); } +void transform::HoistRedundantTensorSubsetsOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getTarget(), effects); + transform::modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // InsertSliceToCopyOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -302,7 +302,7 @@ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation transform.structured.hoist_redundant_tensor_subsets %0 - : (!pdl.operation) -> !pdl.operation + : (!pdl.operation) -> () } // ----- @@ -397,7 +397,7 @@ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation transform.structured.hoist_redundant_tensor_subsets %0 - : (!pdl.operation) -> !pdl.operation + : (!pdl.operation) -> () } // ----- @@ -514,7 +514,7 @@ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation transform.structured.hoist_redundant_tensor_subsets %0 - : (!pdl.operation) -> !pdl.operation + : (!pdl.operation) -> () } // ----- @@ -561,7 +561,7 @@ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation transform.structured.hoist_redundant_tensor_subsets %0 - : (!pdl.operation) -> !pdl.operation + : (!pdl.operation) -> () } // ----- @@ -674,7 +674,7 @@ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation transform.structured.hoist_redundant_tensor_subsets %0 - : (!pdl.operation) -> !pdl.operation + : (!pdl.operation) -> () } // -----