diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/AllocTensorElimination.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/EmptyTensorElimination.h rename from mlir/include/mlir/Dialect/Bufferization/Transforms/AllocTensorElimination.h rename to mlir/include/mlir/Dialect/Bufferization/Transforms/EmptyTensorElimination.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/AllocTensorElimination.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/EmptyTensorElimination.h @@ -1,4 +1,4 @@ -//===- AllocTensorElimination.h - alloc_tensor op elimination -------------===// +//===- EmptyTensorElimination.h - tensor.empty op elimination -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,15 +6,15 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ALLOCTENSORELIMINATION_H -#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ALLOCTENSORELIMINATION_H +#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_EMPTYTENSORELIMINATION_H +#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_EMPTYTENSORELIMINATION_H #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" namespace mlir { namespace bufferization { -/// A function that matches anchor OpOperands for AllocTensorOp elimination. +/// A function that matches anchor OpOperands for tensor::EmptyOp elimination. /// If an OpOperand is matched, the function should populate the SmallVector /// with all values that are needed during `RewriteFn` to produce the /// replacement value. @@ -23,26 +23,26 @@ /// A function that rewrites matched anchors. using RewriteFn = std::function; -/// Try to eliminate AllocTensorOps inside `op`. +/// Try to eliminate tensor::EmptyOps inside `op`. /// -/// * `rewriteFunc` generates the replacement for the AllocTensorOp. 
-/// * Only AllocTensorOps that are anchored on a matching OpOperand as per +/// * `rewriteFunc` generates the replacement for the tensor::EmptyOp. +/// * Only tensor::EmptyOps that are anchored on a matching OpOperand as per /// `anchorMatchFunc` are considered. "Anchored" means that there is a path /// on the reverse SSA use-def chain, starting from the OpOperand and always /// following the aliasing OpOperand, that eventually ends at a single -/// AllocTensorOp. -LogicalResult eliminateAllocTensors(RewriterBase &rewriter, Operation *op, +/// tensor::EmptyOp. +LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state, AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc); -/// Try to eliminate AllocTensorOps inside `op` that are anchored on an +/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on an /// InsertSliceOp, i.e., if it is eventually inserted into another tensor /// (and some other conditions are met). -LogicalResult insertSliceAnchoredAllocTensorEliminationStep( +LogicalResult insertSliceAnchoredEmptyTensorEliminationStep( RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state); } // namespace bufferization } // namespace mlir -#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ALLOCTENSORELIMINATION_H +#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_EMPTYTENSORELIMINATION_H diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -90,9 +90,9 @@ std::unique_ptr createPromoteBuffersToStackPass(std::function isSmallAlloc); -/// Create a pass that tries to eliminate alloc_tensor ops that are anchored on +/// Create a pass that tries to eliminate tensor.empty ops that are anchored on /// insert_slice ops. 
-std::unique_ptr createAllocTensorEliminationPass(); +std::unique_ptr createEmptyTensorEliminationPass(); /// Create a pass that bufferizes ops from the bufferization dialect. std::unique_ptr createBufferizationBufferizePass(); diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -371,16 +371,16 @@ let constructor = "mlir::bufferization::createTensorCopyInsertionPass()"; } -def AllocTensorElimination : Pass<"eliminate-alloc-tensors"> { - let summary = "Try to eliminate all alloc_tensor ops."; +def EmptyTensorElimination : Pass<"eliminate-empty-tensors"> { + let summary = "Try to eliminate all tensor.empty ops."; let description = [{ - This pass tries to eliminate all insert_slice op-anchored alloc_tensor ops. - I.e., when a value that is equivalent to an alloc_tensor op is inserted into + This pass tries to eliminate all insert_slice op-anchored tensor.empty ops. + I.e., when a value that is equivalent to a tensor.empty op is inserted into another tensor, this pass tries to rewrite the IR in such a way that the destination tensor of the insert_slice op is used directly instead of the - alloc_tensor result. + tensor.empty result. 
}]; - let constructor = "mlir::bufferization::createAllocTensorEliminationPass()"; + let constructor = "mlir::bufferization::createEmptyTensorEliminationPass()"; } #endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_dialect_library(MLIRBufferizationTransforms - AllocTensorElimination.cpp Bufferize.cpp BufferDeallocation.cpp BufferOptimizations.cpp @@ -7,6 +6,7 @@ BufferUtils.cpp BufferViewFlowAnalysis.cpp DropEquivalentBufferResults.cpp + EmptyTensorElimination.cpp EmptyTensorToAllocTensor.cpp FuncBufferizableOpInterfaceImpl.cpp OneShotAnalysis.cpp diff --git a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp rename from mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp rename to mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -1,4 +1,4 @@ -//===- AllocTensorElimination.cpp - alloc_tensor op elimination -----------===// +//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -10,7 +10,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/AllocTensorElimination.h" +#include "mlir/Dialect/Bufferization/Transforms/EmptyTensorElimination.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dominance.h" @@ -18,7 +18,7 @@ namespace mlir { namespace bufferization { -#define GEN_PASS_DEF_ALLOCTENSORELIMINATION +#define GEN_PASS_DEF_EMPTYTENSORELIMINATION #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" } // namespace bufferization } // namespace mlir @@ -47,27 +47,27 @@ } /// Return true if the given `insertionPoint` dominates all uses of -/// `allocTensorOp`. +/// `emptyTensorOp`. static bool insertionPointDominatesUses(const DominanceInfo &domInfo, Operation *insertionPoint, - Operation *allocTensorOp) { - for (Operation *user : allocTensorOp->getUsers()) + Operation *emptyTensorOp) { + for (Operation *user : emptyTensorOp->getUsers()) if (!domInfo.dominates(insertionPoint, user)) return false; return true; } -/// Find a valid insertion point for a replacement of `allocTensorOp`, assuming +/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming /// that the replacement may use any value from `neededValues`. static Operation * -findValidInsertionPoint(Operation *allocTensorOp, +findValidInsertionPoint(Operation *emptyTensorOp, const SmallVector &neededValues) { DominanceInfo domInfo; - // Gather all possible insertion points: the location of `allocTensorOp` and + // Gather all possible insertion points: the location of `emptyTensorOp` and // right after the definition of each value in `neededValues`. 
SmallVector insertionPointCandidates; - insertionPointCandidates.push_back(allocTensorOp); + insertionPointCandidates.push_back(emptyTensorOp); for (Value val : neededValues) { // Note: The anchor op is using all of `neededValues`, so: // * in case of a block argument: There must be at least one op in the block @@ -90,7 +90,7 @@ neededValues)) continue; // Check if the insertion point is before all uses. - if (!insertionPointDominatesUses(domInfo, insertionPoint, allocTensorOp)) + if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp)) continue; return insertionPoint; } @@ -99,12 +99,12 @@ return nullptr; } -/// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp is replaced +/// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp is replaced /// with the result of `rewriteFunc` if it is anchored on a matching /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def /// chain, starting from the OpOperand and always following the aliasing -/// OpOperand, that eventually ends at a single AllocTensorOp. -LogicalResult mlir::bufferization::eliminateAllocTensors( +/// OpOperand, that eventually ends at a single tensor::EmptyOp. +LogicalResult mlir::bufferization::eliminateEmptyTensors( RewriterBase &rewriter, Operation *op, AnalysisState &state, AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) { OpBuilder::InsertionGuard g(rewriter); @@ -119,56 +119,40 @@ // Is this a matching OpOperand? if (!anchorMatchFunc(operand, neededValues)) continue; - SetVector maybeAllocTensor = - state.findValueInReverseUseDefChain(operand.get(), [&](Value val) { - // Continue traversal until this function returns true. - OpResult opResult = val.dyn_cast(); - if (!opResult) - return true; - SmallVector opOperands = - state.getAliasingOpOperand(opResult); - if (!llvm::all_of(opOperands, [&](OpOperand *operand) { - return state.isInPlace(*operand); - })) - return true; - // Only equivalent tensors are supported at the moment. 
- // TODO: Support cases such as extract_slice(alloc_tensor) - return !llvm::all_of(opOperands, [&](OpOperand *operand) { - return state.areEquivalentBufferizedValues(operand->get(), - opResult); - }); - }); + SetVector maybeEmptyTensor = state.findValueInReverseUseDefChain( + operand.get(), /*condition=*/[&](Value val) { return false; }, + /*followEquivalentOnly=*/true); // Replace only if the reverse use-def chain ends at exactly one - // AllocTensorOp. - if (maybeAllocTensor.size() != 1 || - !maybeAllocTensor.front().getDefiningOp()) + // tensor::EmptyOp. + if (maybeEmptyTensor.size() != 1 || + !maybeEmptyTensor.front().getDefiningOp()) return WalkResult::skip(); - Value allocTensor = maybeAllocTensor.front(); + Value emptyTensor = maybeEmptyTensor.front(); // Replace only if the types match. // TODO: This could be extended to support IR such as: - // %0 = bufferization.alloc_tensor : tensor<128xf32> + // %0 = tensor.empty() : tensor<128xf32> // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) // %2 = tensor.expand_shape %1 ... // %3 = tensor.insert_slice %2 into ... - if (allocTensor.getType() != operand.get().getType()) + if (emptyTensor.getType() != operand.get().getType()) return WalkResult::skip(); // Find a suitable insertion point. Operation *insertionPoint = - findValidInsertionPoint(allocTensor.getDefiningOp(), neededValues); + findValidInsertionPoint(emptyTensor.getDefiningOp(), neededValues); if (!insertionPoint) continue; - // Create a replacement for the AllocTensorOp. + // Create a replacement for the tensor::EmptyOp. rewriter.setInsertionPoint(insertionPoint); - Value replacement = rewriteFunc(rewriter, allocTensor.getLoc(), operand); + Value replacement = rewriteFunc(rewriter, emptyTensor.getLoc(), operand); if (!replacement) continue; - // Replace the AllocTensorOp. - rewriter.replaceOp(allocTensor.getDefiningOp(), replacement); + // Replace the tensor::EmptyOp. 
+ rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement); } // Advance to the next operation. @@ -178,34 +162,35 @@ return failure(status.wasInterrupted()); } -/// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp can be +/// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp can be /// eliminated if it is eventually inserted into another tensor (and some other /// conditions are met). /// /// E.g.: -/// %0 = linalg.alloc_tensor +/// %0 = tensor.empty() /// %1 = linalg.fill(%cst, %0) {inplace = [true]} /// %2 = tensor.insert_slice %1 into %t[10][20][1] /// -/// AllocTensorOp elimination will try to fill %t inplace instead of filling a +/// tensor::EmptyOp elimination will try to fill %t inplace instead of filling a /// new allocation %0 and inserting it into %t. This is done by replacing the -/// AllocTensorOp with: +/// tensor::EmptyOp with: /// /// %0 = tensor.extract_slice %t[10][20][1] /// /// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets /// those bufferize inplace in the absence of other conflicts. /// -/// Starting from an InsertSliceOp, an AllocTensorOp at the end of the insert +/// Starting from an InsertSliceOp, a tensor::EmptyOp at the end of the insert /// source's reverse use-def chain is eliminated if: /// * On the reverse use-def chain path from the InsertSliceOp to the -/// AllocTensorOp, all ops were decided to bufferize inplace and the buffer +/// tensor::EmptyOp, all ops were decided to bufferize inplace and the buffer /// relation is "equivalent" (TODO: can be relaxed if needed). -/// * The reverse use-def chain has exactly one end, which is the AllocTensorOp. +/// * The reverse use-def chain has exactly one end, which is the +/// tensor::EmptyOp. 
LogicalResult -mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep( +mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep( RewriterBase &rewriter, Operation *op, AnalysisState &state) { - return eliminateAllocTensors( + return eliminateEmptyTensors( rewriter, op, state, /*anchorMatchFunc=*/ [&](OpOperand &operand, SmallVector &neededValues) { @@ -239,10 +224,10 @@ } namespace { -struct AllocTensorElimination - : public bufferization::impl::AllocTensorEliminationBase< - AllocTensorElimination> { - AllocTensorElimination() = default; +struct EmptyTensorElimination + : public bufferization::impl::EmptyTensorEliminationBase< + EmptyTensorElimination> { + EmptyTensorElimination() = default; void runOnOperation() override; @@ -253,7 +238,7 @@ }; } // namespace -void AllocTensorElimination::runOnOperation() { +void EmptyTensorElimination::runOnOperation() { Operation *op = getOperation(); OneShotBufferizationOptions options; OneShotAnalysisState state(op, options); @@ -263,11 +248,11 @@ } IRRewriter rewriter(op->getContext()); - if (failed(bufferization::insertSliceAnchoredAllocTensorEliminationStep( + if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep( rewriter, op, state))) signalPassFailure(); } -std::unique_ptr mlir::bufferization::createAllocTensorEliminationPass() { - return std::make_unique(); +std::unique_ptr mlir::bufferization::createEmptyTensorEliminationPass() { + return std::make_unique(); } diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir rename from mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir --- a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir +++ 
b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir @@ -1,8 +1,4 @@ -// RUN: mlir-opt %s -eliminate-alloc-tensors -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs" -split-input-file | FileCheck %s - -//===----------------------------------------------------------------------===// -// AllocTensorOp elimination -//===----------------------------------------------------------------------===// +// RUN: mlir-opt %s -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs" -split-input-file | FileCheck %s // CHECK-LABEL: func @buffer_forwarding_conflict func.func @buffer_forwarding_conflict(%arg0: tensor {bufferization.writable = true}, %arg1: index) -> (tensor, tensor) { @@ -10,7 +6,7 @@ // CHECK: tensor.extract_slice // CHECK-SAME: {__inplace_operands_attr__ = ["false", "none"] // Instead of allocating, share buffer with some inplace bufferization? - %0 = bufferization.alloc_tensor(%arg1) : tensor + %0 = tensor.empty(%arg1) : tensor // CHECK: linalg.fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"] @@ -37,7 +33,7 @@ // CHECK: tensor.extract_slice // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"] // Instead of allocating, share buffer with some inplace bufferization? 
- %0 = bufferization.alloc_tensor(%arg1) : tensor + %0 = tensor.empty(%arg1) : tensor // CHECK: linalg.fill // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"] diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -eliminate-alloc-tensors -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs" -canonicalize -split-input-file | FileCheck %s +// RUN: mlir-opt %s -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs" -canonicalize -split-input-file | FileCheck %s // CHECK: func @buffer_forwarding_conflict( // CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref @@ -16,10 +16,10 @@ // CHECK: %[[DIM:.*]] = memref.dim %[[FUNC_ARG]] // This allocs the whole dim to allow for a full clone of t. // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) - // alloc_tensor itself does not alloc but forwards to the **second** - // insert_slice. AllocTensorOp replaces the alloc_tensor with an out-of-place + // tensor.empty itself does not alloc but forwards to the **second** + // insert_slice. The pass replaces the tensor.empty with an out-of-place // extract_slice. 
- %a = bufferization.alloc_tensor(%sz) : tensor + %a = tensor.empty(%sz) : tensor %f = linalg.fill ins(%f0 : f32) outs(%a : tensor) -> tensor // CHECK: memref.copy %[[FUNC_ARG]], %[[ALLOC]] : memref to memref @@ -46,11 +46,11 @@ { %f0 = arith.constant 0.0: f32 - // alloc_tensor itself does not alloc but forwards to the insert_slice. - // AllocTensorOpElimination replaces the alloc_tensor with an inplace + // tensor.empty itself does not alloc but forwards to the insert_slice. + // EmptyTensorElimination replaces the tensor.empty with an inplace // extract_slice. // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] - %a = bufferization.alloc_tensor(%sz) : tensor + %a = tensor.empty(%sz) : tensor // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW]] : memref) -> tensor @@ -71,7 +71,7 @@ %c5 = arith.constant 5 : index // CHECK-NOT: memref.alloc - %blank = bufferization.alloc_tensor() : tensor<5xf32> + %blank = tensor.empty() : tensor<5xf32> // CHECK: scf.for %[[iv:.*]] = %{{.*}} to %[[sz]] step %{{.*}} { %r = scf.for %iv = %c0 to %sz step %c5 iter_args(%bb = %t) -> (tensor) { @@ -102,7 +102,7 @@ // CHECK-NOT: memref.alloc // CHECK: %[[subview:.*]] = memref.subview %[[t]][%[[idx]]] [5] [1] - %blank = bufferization.alloc_tensor() : tensor<5xf32> + %blank = tensor.empty() : tensor<5xf32> // CHECK: scf.for %[[iv:.*]] = %{{.*}} to %[[sz]] step %{{.*}} { %r = scf.for %iv = %c0 to %sz step %c5 iter_args(%bb = %t) -> (tensor) { @@ -122,14 +122,14 @@ // ----- -// AllocTensorElimination does currently not apply to chains where the type is +// EmptyTensorElimination currently does not apply to chains where the type is // changing. This test just ensures that we do not crash or generate IR that // does not verify. 
// CHECK-LABEL: func @shape_mismatch func.func @shape_mismatch(%t: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> { %cst = arith.constant 8.0 : f32 - %0 = bufferization.alloc_tensor() : tensor<128xf32> + %0 = tensor.empty() : tensor<128xf32> %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128xf32>) -> tensor<128xf32> %2 = tensor.expand_shape %1 [[0, 1, 2]] : tensor<128xf32> into tensor<1x1x128xf32>