diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -84,7 +84,10 @@ /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the /// beginning the alias and equivalence sets only contain `v` itself. - void createAliasInfoEntry(Value v); + void createAliasEntry(Value v); + + /// Add `aliasInfo` and `equivalentInfo` entries for a newly created op. + void createAliasEntriesForNewOp(Operation *op); /// Find all tensor values in the given operation that have undefined contents /// and store them in `undefinedTensorUses`. diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h @@ -16,6 +16,7 @@ namespace bufferization { class AnalysisState; struct BufferizationStatistics; +class OneShotAnalysisState; struct OneShotBufferizationOptions; /// A function that matches anchor OpOperands for tensor::EmptyOp elimination. @@ -35,16 +36,16 @@ /// on the reverse SSA use-def chain, starting from the OpOperand and always /// following the aliasing OpOperand, that eventually ends at a single /// tensor::EmptyOp. -LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op, - bufferization::AnalysisState &state, +LogicalResult eliminateEmptyTensors(Operation *op, OneShotAnalysisState &state, AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc); /// Try to eliminate tensor::EmptyOps inside `op` that are anchored on an /// InsertSliceOp, i.e., if it is eventually inserted into another tensor /// (and some other conditions are met). 
-LogicalResult insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);
+LogicalResult
+insertSliceAnchoredEmptyTensorEliminationStep(Operation *op,
+                                              OneShotAnalysisState &state);
 
 /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
 /// After applying this transform, the IR can be bufferized without inserting
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -26,6 +26,30 @@
 using namespace mlir;
 using namespace mlir::bufferization;
 
+namespace {
+/// A rewriter that tracks new ops and modified ops.
+class TrackingRewriter : public RewriterBase {
+public:
+  TrackingRewriter(MLIRContext *ctx, OneShotAnalysisState &state)
+      : RewriterBase(ctx), state(state) {}
+
+  void notifyOperationInserted(Operation *op) override {
+    state.createAliasEntriesForNewOp(op);
+    newOps.push_back(op);
+  }
+
+  void finalizeRootUpdate(Operation *op) override { modifiedOps.push_back(op); }
+
+  ArrayRef<Operation *> getNewOps() const { return newOps; }
+  ArrayRef<Operation *> getModifiedOps() const { return modifiedOps; }
+
+private:
+  OneShotAnalysisState &state;
+
+  SmallVector<Operation *> newOps, modifiedOps;
+};
+} // namespace
+
 /// Return true if all `neededValues` are in scope at the given
 /// `insertionPoint`.
 static bool
@@ -105,9 +129,9 @@
 /// chain, starting from the OpOperand and always following the aliasing
 /// OpOperand, that eventually ends at a single tensor::EmptyOp.
LogicalResult mlir::bufferization::eliminateEmptyTensors( - RewriterBase &rewriter, Operation *op, AnalysisState &state, - AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) { - OpBuilder::InsertionGuard g(rewriter); + Operation *op, OneShotAnalysisState &state, AnchorMatchFn anchorMatchFunc, + RewriteFn rewriteFunc) { + TrackingRewriter rewriter(op->getContext(), state); WalkResult status = op->walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { @@ -152,14 +176,27 @@ continue; // Replace the tensor::EmptyOp. - rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement); + rewriter.replaceAllUsesWith(emptyTensor, replacement); } // Advance to the next operation. return WalkResult::advance(); }); - return failure(status.wasInterrupted()); + if (status.wasInterrupted()) + return failure(); + + DominanceInfo domInfo; + // Re-analyze modified ops. + for (Operation *op : rewriter.getModifiedOps()) + if (failed(state.analyzeSingleOp(op, domInfo))) + return failure(); + // Analyze new ops. + for (Operation *op : rewriter.getNewOps()) + if (failed(state.analyzeSingleOp(op, domInfo))) + return failure(); + + return success(); } /// Try to eliminate tensor::EmptyOps inside `op`. An tensor::EmptyOp can be @@ -188,10 +225,11 @@ /// * The reverse use-def chain has exactly one end, which is the /// tensor::EmptyOp. 
 template <typename OpTy>
-static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+static LogicalResult
+insertSliceLikeAnchoredEmptyTensorEliminationStep(Operation *op,
+                                                  OneShotAnalysisState &state) {
   return eliminateEmptyTensors(
-      rewriter, op, state,
+      op, state,
       /*anchorMatchFunc=*/
       [&](OpOperand &operand, SmallVector<Value> &neededValues) {
         auto insertSliceOp = dyn_cast<OpTy>(operand.getOwner());
@@ -224,12 +262,12 @@
 
 LogicalResult
 mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+    Operation *op, OneShotAnalysisState &state) {
   if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
-             tensor::InsertSliceOp>(rewriter, op, state)))
+             tensor::InsertSliceOp>(op, state)))
     return failure();
   if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
-             tensor::ParallelInsertSliceOp>(rewriter, op, state)))
+             tensor::ParallelInsertSliceOp>(op, state)))
     return failure();
   return success();
 }
@@ -258,9 +296,8 @@
     return;
   }
 
-  IRRewriter rewriter(op->getContext());
   if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-          rewriter, op, state)))
+          op, state)))
     signalPassFailure();
 }
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -104,28 +104,7 @@
 OneShotAnalysisState::OneShotAnalysisState(
     Operation *op, const OneShotBufferizationOptions &options)
     : AnalysisState(options, TypeID::get<OneShotAnalysisState>()) {
-  // Set up alias sets.
-  op->walk([&](Operation *op) {
-    for (Value v : op->getResults())
-      if (v.getType().isa<TensorType>())
-        createAliasInfoEntry(v);
-    for (Region &r : op->getRegions())
-      for (Block &b : r.getBlocks())
-        for (auto bbArg : b.getArguments())
-          if (bbArg.getType().isa<TensorType>())
-            createAliasInfoEntry(bbArg);
-  });
-
-  // Mark OpOperands in-place that must bufferize in-place.
-  op->walk([&](BufferizableOpInterface bufferizableOp) {
-    if (!options.isOpAllowed(bufferizableOp))
-      return WalkResult::skip();
-    for (OpOperand &opOperand : bufferizableOp->getOpOperands())
-      if (opOperand.get().getType().isa<TensorType>())
-        if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
-          bufferizeInPlace(opOperand);
-    return WalkResult::advance();
-  });
+  op->walk([&](Operation *op) { createAliasEntriesForNewOp(op); });
 }
 
 void OneShotAnalysisState::applyOnEquivalenceClass(
@@ -159,8 +138,11 @@
   if (inplaceBufferized.contains(&operand))
     return;
   inplaceBufferized.insert(&operand);
-  for (OpResult result : getAliasingOpResults(operand))
+  for (OpResult result : getAliasingOpResults(operand)) {
+    aliasInfo.insert(result);
+    aliasInfo.insert(operand.get());
     aliasInfo.unionSets(result, operand.get());
+  }
   ++statNumTensorInPlace;
 }
 
@@ -170,11 +152,32 @@
   ++statNumTensorOutOfPlace;
 }
 
-void OneShotAnalysisState::createAliasInfoEntry(Value v) {
+void OneShotAnalysisState::createAliasEntry(Value v) {
   aliasInfo.insert(v);
   equivalentInfo.insert(v);
 }
 
+void OneShotAnalysisState::createAliasEntriesForNewOp(Operation *op) {
+  // Set up alias sets.
+  for (Value v : op->getResults())
+    if (v.getType().isa<TensorType>())
+      createAliasEntry(v);
+  for (Region &r : op->getRegions())
+    for (Block &b : r.getBlocks())
+      for (auto bbArg : b.getArguments())
+        if (bbArg.getType().isa<TensorType>())
+          createAliasEntry(bbArg);
+
+  // Mark OpOperands in-place that must bufferize in-place.
+  auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
+  if (!bufferizableOp)
+    return;
+  for (OpOperand &opOperand : bufferizableOp->getOpOperands())
+    if (opOperand.get().getType().isa<TensorType>())
+      if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
+        bufferizeInPlace(opOperand);
+}
+
 // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
 // to ensure that such information is available during bufferization time.
 // Alias information can no longer be queried once we have started modifying
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -1,4 +1,10 @@
-// RUN: mlir-opt %s -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs" -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -eliminate-empty-tensors -empty-tensor-to-alloc-tensor \
+// RUN:     -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs" \
+// RUN:     -canonicalize -split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s -test-interleaved-empty-tensor-elimination -cse \
+// RUN:     -canonicalize -loop-invariant-code-motion -split-input-file | \
+// RUN:     FileCheck %s
 
 // CHECK: func @buffer_forwarding_conflict(
 // CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRBufferizationTestPasses
+  TestEmptyTensorElimination.cpp
   TestTensorCopyInsertion.cpp
   EXCLUDE_FROM_LIBMLIR
@@ -9,4 +10,5 @@
   MLIRBufferizationTransforms
   MLIRIR
   MLIRPass
+  MLIRTensorDialect
 )
diff --git a/mlir/test/lib/Dialect/Bufferization/TestEmptyTensorElimination.cpp b/mlir/test/lib/Dialect/Bufferization/TestEmptyTensorElimination.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Bufferization/TestEmptyTensorElimination.cpp
@@ -0,0 +1,95 @@
+//===- TestEmptyTensorElimination.cpp ---------------------------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace {
+struct TestInterleavedEmptyTensorEliminationPass
+    : public PassWrapper<TestInterleavedEmptyTensorEliminationPass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestInterleavedEmptyTensorEliminationPass)
+
+  TestInterleavedEmptyTensorEliminationPass() = default;
+  TestInterleavedEmptyTensorEliminationPass(
+      const TestInterleavedEmptyTensorEliminationPass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect,
+                    tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-interleaved-empty-tensor-elimination";
+  }
+  StringRef getDescription() const final {
+    return "Test pass for interleaved empty tensor elimination and analysis";
+  }
+
+  void runOnOperation() override {
+    OneShotBufferizationOptions options;
+    options.allowReturnAllocs = true;
+    options.bufferizeFunctionBoundaries = true;
+
+    OneShotAnalysisState state(getOperation(), options);
+
+    // Analyze ops.
+    if (failed(analyzeOp(getOperation(), state)))
+      return signalPassFailure();
+
+    // Eliminate tensor.empty ops.
+    if (failed(insertSliceAnchoredEmptyTensorEliminationStep(getOperation(),
+                                                             state)))
+      return signalPassFailure();
+
+    // Rewrite remaining tensor.empty ops to bufferization.alloc_tensor ops.
+    OpBuilder b(getOperation()->getContext());
+    DominanceInfo domInfo;
+    WalkResult status = getOperation()->walk([&](tensor::EmptyOp emptyOp) {
+      b.setInsertionPoint(emptyOp);
+      if (!emptyOp->getUsers().empty()) {
+        Value allocTensor = b.create<bufferization::AllocTensorOp>(
+            emptyOp.getLoc(), emptyOp.getType(), emptyOp.getDynamicSizes());
+        state.createAliasEntriesForNewOp(allocTensor.getDefiningOp());
+        SmallVector<OpOperand *> uses;
+        for (OpOperand &use : emptyOp.getResult().getUses())
+          uses.push_back(&use);
+        for (OpOperand *use : uses) {
+          use->set(allocTensor);
+          if (failed(state.analyzeSingleOp(use->getOwner(), domInfo)))
+            return WalkResult::interrupt();
+        }
+        emptyOp.replaceAllUsesWith(allocTensor);
+      }
+      emptyOp->erase();
+      return WalkResult::skip();
+    });
+    if (status.wasInterrupted())
+      return signalPassFailure();
+
+    // Insert tensor copies and bufferize.
+    if (failed(insertTensorCopies(getOperation(), state)))
+      return signalPassFailure();
+    if (failed(bufferizeOp(getOperation(), options, /*copyBeforeWrite=*/false)))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestInterleavedEmptyTensorEliminationPass() {
+  PassRegistration<TestInterleavedEmptyTensorEliminationPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -85,6 +85,7 @@
 void registerTestFooAnalysisPass();
 void registerTestComposeSubView();
 void registerTestMultiBuffering();
+void registerTestInterleavedEmptyTensorEliminationPass();
 void registerTestIntRangeInference();
 void registerTestIRVisitorsPass();
 void registerTestGenericIRVisitorsPass();
@@ -194,6 +195,7 @@
   mlir::test::registerTestFooAnalysisPass();
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestMultiBuffering();
+  mlir::test::registerTestInterleavedEmptyTensorEliminationPass();
   mlir::test::registerTestIntRangeInference();
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestGenericIRVisitorsPass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -712,6 +712,7 @@
         "//mlir:BufferizationTransforms",
         "//mlir:IR",
         "//mlir:Pass",
+        "//mlir:TensorDialect",
     ],
 )