diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
@@ -89,7 +89,8 @@
 
 /// Implements transfer op write to read forwarding and dead transfer write
 /// optimizations.
-void transferOpflowOpt(FuncOp func);
+/// Operates on the memref-based transfer ops nested within `rootOp`.
+void transferOpflowOpt(Operation *rootOp);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -40,7 +40,8 @@
 
 class TransferOptimization {
 public:
-  TransferOptimization(FuncOp func) : dominators(func), postDominators(func) {}
+  explicit TransferOptimization(Operation *op)
+      : dominators(op), postDominators(op) {}
   void deadStoreOp(vector::TransferWriteOp);
   void storeToLoadForwarding(vector::TransferReadOp);
   void removeDeadOp() {
@@ -462,16 +463,16 @@
 
 } // namespace
 
-void mlir::vector::transferOpflowOpt(FuncOp func) {
-  TransferOptimization opt(func);
+void mlir::vector::transferOpflowOpt(Operation *rootOp) {
+  TransferOptimization opt(rootOp);
   // Run store to load forwarding first since it can expose more dead store
   // opportunity.
-  func.walk([&](vector::TransferReadOp read) {
+  rootOp->walk([&](vector::TransferReadOp read) {
     if (read.getShapedType().isa<MemRefType>())
       opt.storeToLoadForwarding(read);
   });
   opt.removeDeadOp();
-  func.walk([&](vector::TransferWriteOp write) {
+  rootOp->walk([&](vector::TransferWriteOp write) {
     if (write.getShapedType().isa<MemRefType>())
       opt.deadStoreOp(write);
   });