diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h
deleted file mode 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h
+++ /dev/null
@@ -1,26 +0,0 @@
-//===- TensorCopyInsertion.h - Resolve Bufferization Conflicts w/ Copies --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TENSORCOPYINSERTION_H
-#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TENSORCOPYINSERTION_H
-
-#include "mlir/IR/Operation.h"
-
-namespace mlir {
-namespace bufferization {
-class AnalysisState;
-struct OneShotBufferizationOptions;
-
-LogicalResult insertTensorCopies(Operation *op,
-                                 const OneShotBufferizationOptions &options);
-
-LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
-} // namespace bufferization
-} // namespace mlir
-
-#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TENSORCOPYINSERTION_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/EmptyTensorElimination.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
rename from mlir/include/mlir/Dialect/Bufferization/Transforms/EmptyTensorElimination.h
rename to mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/EmptyTensorElimination.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -1,4 +1,4 @@
-//===- EmptyTensorElimination.h - tensor.empty op elimination -------------===//
+//===- Transforms.h - Bufferization and related transforms ------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,15 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_EMPTYTENSORELIMINATION_H
-#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_EMPTYTENSORELIMINATION_H
+#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TRANSFORMS_H
 
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 namespace bufferization {
 
 /// A function that matches anchor OpOperands for tensor::EmptyOp elimination.
 /// If an OpOperand is matched, the function should populate the SmallVector
@@ -42,7 +44,22 @@
 LogicalResult insertSliceAnchoredEmptyTensorEliminationStep(
     RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);
 
+/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
+/// After applying this transform, the IR can be bufferized without inserting
+/// additional buffer allocations.
+LogicalResult insertTensorCopies(Operation *op,
+                                 const OneShotBufferizationOptions &options);
+
+/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
+/// After applying this transform, the IR can be bufferized without inserting
+/// additional buffer allocations.
+LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
+
+/// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
+/// ops.
+void populateEmptyTensorToAllocTensorPattern(RewritePatternSet &patterns);
+
 } // namespace bufferization
 } // namespace mlir
 
-#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_EMPTYTENSORELIMINATION_H
+#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -13,7 +13,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
-#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Operation.h"
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -10,8 +10,8 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/Transforms/EmptyTensorElimination.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Pass/Pass.h"
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -50,10 +51,15 @@
 };
 } // namespace
 
+void bufferization::populateEmptyTensorToAllocTensorPattern(
+    RewritePatternSet &patterns) {
+  patterns.insert<EmptyTensorLoweringPattern>(patterns.getContext());
+}
+
 void EmptyTensorToAllocTensor::runOnOperation() {
   Operation *op = getOperation();
   RewritePatternSet patterns(op->getContext());
-  patterns.insert<EmptyTensorLoweringPattern>(op->getContext());
+  populateEmptyTensorToAllocTensorPattern(patterns);
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -46,7 +46,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
-#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AsmState.h"
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -64,7 +64,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Operation.h"
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -13,7 +13,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
-#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 
 namespace mlir {