diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
@@ -43,6 +43,12 @@
     dialect), and does not live in this dialect.
   }];
+
+  let hasOperationAttrVerify = 1;
+  let extraClassDeclaration = [{
+    constexpr const static ::llvm::StringLiteral
+        kInplaceableAttrName = "tensor.inplaceable";
+  }];
 }
 
 #endif // TENSOR_BASE
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -53,6 +53,12 @@
 std::unique_ptr<Pass>
 createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
 
+/// Creates a pass that bufferizes the body of a FuncOp and tries to reuse the
+/// buffers for those arguments that:
+///   a) have been annotated 'inplaceable' and
+///   b) whose buffer uses would be free of memory hazards.
+std::unique_ptr<Pass> createComprehensiveFuncBufferizePass();
+
 /// Creates a pass that finalizes a partial bufferization by removing remaining
 /// tensor_load and buffer_cast operations.
 std::unique_ptr<FunctionPass> createFinalizingBufferizePass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -380,6 +380,20 @@
   ];
 }
 
+def ComprehensiveFuncBufferize : FunctionPass<"comprehensive-func-bufferize"> {
+  let summary = "Bufferize the body of a FuncOp and try to reuse the buffers "
+                "for those arguments that a) have been annotated 'inplaceable' and "
+                "b) whose buffer uses would be free of memory hazards";
+  let description = [{
+    This pass implements a cross-dialect bufferization approach and performs an
+    analysis to determine which op operands and results may be bufferized in the
+    same buffers. The analysis is performed on SSA use-def chains starting from
+    function operands that are annotated with the 'inplaceable' attribute.
+  }];
+  let constructor = "mlir::createComprehensiveFuncBufferizePass()";
+  let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
+}
+
 def Inliner : Pass<"inline"> {
   let summary = "Inline function calls";
   let constructor = "mlir::createInlinerPass()";
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
@@ -8,6 +8,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/FunctionSupport.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -147,6 +148,8 @@
 // TensorDialect Dialect Interfaces
 //===----------------------------------------------------------------------===//
 
+constexpr const ::llvm::StringLiteral TensorDialect::kInplaceableAttrName;
+
 namespace {
 struct TensorInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
@@ -196,3 +199,19 @@
   if (succeeded(generatedAttributePrinter(attr, printer)))
     return;
 }
+
+LogicalResult TensorDialect::verifyOperationAttribute(Operation *op,
+                                                      NamedAttribute attr) {
+  if (attr.first == TensorDialect::kInplaceableAttrName) {
+    if (!attr.second.isa<BoolAttr>()) {
+      return op->emitError() << "'" << TensorDialect::kInplaceableAttrName
+                             << "' is expected to be a boolean attribute";
+    }
+    if (!op->hasTrait<OpTrait::FunctionLike>())
+      return op->emitError() << "expected " << attr.first
+                             << " to be used on function-like operations";
+    return success();
+  }
+  return op->emitError() << "attribute '" << attr.first
+                         << "' not supported by the tensor dialect";
+}
diff --git a/mlir/lib/Transforms/ComprehensiveBufferization.cpp b/mlir/lib/Transforms/ComprehensiveBufferization.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/ComprehensiveBufferization.cpp
@@ -0,0 +1,815 @@
+//===- ComprehensiveBufferization.cpp - Single pass bufferization ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Perform bufferization within the function boundaries. Bufferization occurs
+// by:
+//  a. performing an inPlace analysis `inPlaceAnalysisFuncOpInternals`
+//     which marks each operation within the function with the
+//     `kInPlaceResultsAttrName` attribute.
+//  b. traversing each operation in the function and rewriting it in
+//     buffer form and keeping a BlockAndValueMapping mapping of the
+//     rewrites. New allocations are introduced during this step.
+//     TODO: Allocation + depending op hoisting to outermost enclosing
+//     sequential scope.
+//  c. at the end of this bufferization, 2 cases may occur:
+//     * inplaceable function arguments may be reused in place after the
+//       function itself has been bufferized. This is encoded by IR resembling:
+//
+//       ```
+//       #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+//       func @foo(%A: tensor<?xf32> {tensor.inplaceable = true}) -> tensor<?xf32> {
+//         %0 = memref.buffer_cast %A : memref<?xf32, #map>
+//         // ... uses of %0
+//         %res = memref.tensor_load %0 : memref<?xf32, #map>
+//         return %res : tensor<?xf32>
+//       }
+//       ```
+//
+//       this is the cue that the function foo (and calls to it) may bufferize
+//       to `func @foo(%A: memref<?xf32, some_layout>)`.
+//       To fully achieve bufferization, an additional analysis is needed to
+//       determine whether function argument/operand pairs bufferize to a single
+//       inplace buffer argument (i.e. functions may return tensors in arbitrary
+//       order that may not match argument numbers).
+//     * results that don't map to an inplaceable function argument must be
+//       allocated. Since memref semantics wrt ownership of the underlying
+//       memory region are not well-defined, comprehensive bufferization chooses
+//       to perform allocations in a scoped fashion: returning memrefs is always
+//       considered illegal. Such scenarios are encoded by IR resembling:
+//
+//       ```
+//       #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+//       func @foo(%A: tensor<?xf32> {tensor.inplaceable = true}) -> tensor<?xf32> {
+//         %0 = memref.buffer_cast %A : memref<?xf32, #map>
+//         %1 = memref.dim %0, %c0 : memref<?xf32, #map>
+//         %2 = memref.alloc(%1) : memref<?xf32>
+//         %3 = memref.cast %2 : memref<?xf32> to memref<?xf32, #map>
+//         // ... uses of %3
+//         memref.dealloc %2 : memref<?xf32>
+//         %res = memref.tensor_load %3 : memref<?xf32, #map>
+//         return %res : tensor<?xf32>
+//       }
+//       ```
+//
+//       this is the cue that the function foo (and calls to it) must bufferize
+//       to `func @foo(%A: memref<?xf32, some_layout>,
+//                     %B: memref<?xf32, some_layout>)` (i.e. make a cloned
+//       allocation of the result tensor).
+//       To fully achieve bufferization, the alloc/dealloc pair must be lifted
+//       out of the function at each call site.
+//
+// Lastly, note that the layout map chosen for bufferization is the most dynamic
+// canonical strided layout of the proper rank. This ensures compatibility with
+// expected layouts after transformations. Combinations of memref.cast +
+// canonicalization are responsible for cleanups.
+
+#include "PassDetail.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/BufferUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "comprehensive-func-bufferize"
+
+using namespace mlir;
+using namespace linalg;
+using namespace tensor;
+
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific attribute manipulation.
+//===----------------------------------------------------------------------===//
+
+/// Attribute marker to specify op results that can be bufferized inPlace.
+constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_results_attr__";
+
+// TODO: proper enum.
+enum class InPlaceSpec {
+  False,
+  True,
+  None,
+};
+
+static StringRef stringify(InPlaceSpec val) {
+  switch (val) {
+  case InPlaceSpec::False:
+    return "false";
+  case InPlaceSpec::True:
+    return "true";
+  case InPlaceSpec::None:
+    return "none";
+  }
+  return "";
+}
+
+static Optional<InPlaceSpec> symbolize(StringRef str) {
+  return StringSwitch<Optional<InPlaceSpec>>(str)
+      .Case("false", InPlaceSpec::False)
+      .Case("true", InPlaceSpec::True)
+      .Case("none", InPlaceSpec::None)
+      .Default(None);
+}
+
+/// Factor out the logic that matches tied OpResult to BlockArgument.
+// TODO: TiedOpResultInterface.
+static OpResult getTiedOpResult(BlockArgument &bbArg) {
+  if (auto funcOp = dyn_cast<FuncOp>(bbArg.getOwner()->getParentOp()))
+    return OpResult();
+  // TODO: more ops and interfaces.
+  return OpResult();
+}
+
+/// Factor out the logic that matches tied OpResult to OpOperand.
+// TODO: TiedOpResultInterface.
+static OpResult getTiedOpResult(OpOperand &opOperand) {
+  Operation *op = opOperand.getOwner();
+  if (auto linalgOp = dyn_cast<LinalgOp>(op)) {
+    if (opOperand.getOperandNumber() < linalgOp.getNumInputs())
+      return OpResult();
+    return linalgOp->getResult(opOperand.getOperandNumber() -
+                               linalgOp.getNumInputs());
+  }
+  // Terminators have no results.
+  if (op->hasTrait<mlir::OpTrait::IsTerminator>())
+    return OpResult();
+  // TODO: more ops and interfaces.
+  return OpResult();
+}
+
+/// Mark whether OpResult can actually be bufferized inplace. If `inPlace` is
+/// `InPlaceSpec::True`, the use-def chain analysis has guaranteed that no
+/// subsequent write to the tensor value occurs and the result can be bufferized
+/// inPlace.
+static void setInPlaceOpResult(OpResult opResult,
+                               InPlaceSpec inPlace = InPlaceSpec::True) {
+  if (!opResult)
+    return;
+
+  Operation *op = opResult.getOwner();
+  auto attr =
+      op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
+  SmallVector<StringRef> inPlaceVector =
+      attr ? SmallVector<StringRef>(
+                 llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()))
+           : SmallVector<StringRef>(op->getNumResults(),
+                                    stringify(InPlaceSpec::None));
+  LLVM_DEBUG(DBGS() << "Set inPlace=" << stringify(inPlace) << ": " << *op
+                    << " @idx=" << opResult.getResultNumber() << "\n");
+  inPlaceVector[opResult.getResultNumber()] = stringify(inPlace);
+  op->setAttr(kInPlaceResultsAttrName,
+              OpBuilder(op).getStrArrayAttr(inPlaceVector));
+}
+
+/// Get the InPlaceSpec attribute entry `kInPlaceResultsAttrName` for
+/// `opResult`. If the result is `InPlaceSpec::True`, the use-def chain analysis
+/// has guaranteed that no subsequent read of the tensor value occurs and the
+/// result can be bufferized inPlace.
+/// If no InPlaceSpec attribute has been set for `opResult`, return
+/// InPlaceSpec::None.
+static InPlaceSpec getInPlace(OpResult opResult) {
+  if (!opResult)
+    return InPlaceSpec::None;
+
+  Operation *op = opResult.getOwner();
+  auto attr =
+      op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
+  if (!attr)
+    return InPlaceSpec::None;
+
+  // Must return a proper value.
+  return *symbolize(*(attr.getAsValueRange<StringAttr>().begin() +
+                      opResult.getResultNumber()));
+}
+
+/// Get inPlace information for `bbArg`.
+/// If it does not come from a function, return InPlaceSpec::False.
+static InPlaceSpec getInPlace(BlockArgument bbArg) {
+  auto funcOp = dyn_cast<FuncOp>(bbArg.getOwner()->getParentOp());
+  if (!funcOp)
+    return InPlaceSpec::False;
+  auto attr = funcOp.getArgAttrOfType<BoolAttr>(
+      bbArg.getArgNumber(), TensorDialect::kInplaceableAttrName);
+  if (!attr)
+    return InPlaceSpec::None;
+
+  return attr.getValue() ? InPlaceSpec::True : InPlaceSpec::False;
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific BlockAndValueMapping support with debugging.
+//===----------------------------------------------------------------------===//
+
+/// Wrapper for better debugging.
+static void map(BlockAndValueMapping &bvm, ValueRange key, ValueRange value) {
+  if (key.empty())
+    return;
+  LLVM_DEBUG(DBGS() << "Map: " << key.front() << " to " << value.front()
+                    << "\n");
+  return bvm.map(key, value);
+}
+
+/// Wrapper for better debugging.
+static void map(BlockAndValueMapping &bvm, Value key, Value value) {
+  LLVM_DEBUG(DBGS() << "Map: " << key << " to " << value << "\n");
+  return bvm.map(key, value);
+}
+
+/// Wrapper for better debugging.
+static Value lookup(BlockAndValueMapping &bvm, Value key) {
+  // TODO: if key comes from bbArg, forward.
+  assert(key.getType().isa<TensorType>());
+  if (!bvm.lookupOrNull(key)) {
+    if (auto bbArg = key.dyn_cast<BlockArgument>()) {
+      if (isa<FuncOp>(key.getParentBlock()->getParentOp()))
+        key.getParentBlock()->getParentOp()->dump();
+      else
+        key.getParentBlock()->getParentOp()->getParentOfType<FuncOp>()->dump();
+      bbArg.getOwner()->getParentOp()->dump();
+    } else {
+      key.getDefiningOp()->getParentOfType<FuncOp>()->dump();
+    }
+    llvm::errs() << "NO VALUE FOR KEY: " << key << "\n";
+    abort();
+  }
+  return bvm.lookup(key);
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific support.
+//===----------------------------------------------------------------------===//
+
+/// Determine whether any subsequent read of the tensor `opOperand` may occur.
+/// For now, this assumes any use is a read. If any use of the tensor does not
+/// properly dominate `opOperand.getOwner()`, then the tensor cannot be
+/// bufferized inPlace.
+// TODO: For now, this assumes any use is a read. Refine this.
+bool hasInterferingTensorRead(OpOperand &opOperand,
+                              const DominanceInfo &domInfo) {
+  if (!opOperand.get().getType().isa<RankedTensorType>())
+    return false;
+  for (auto &use : opOperand.get().getUses()) {
+    Operation *user = use.getOwner();
+
+    // If the other use properly dominates the op, there is a clear sequence
+    // point and we can dismiss the read.
+    if (domInfo.properlyDominates(user, opOperand.getOwner()))
+      continue;
+    // Otherwise, we need to analyze self-dependencies, for now just let it go.
+    // TODO: proper self-dependence analysis.
+    if (domInfo.dominates(user, opOperand.getOwner()))
+      continue;
+    if (user == opOperand.getOwner() &&
+        use.getOperandNumber() == opOperand.getOperandNumber())
+      continue;
+    LLVM_DEBUG(DBGS() << "found interfering read operand #"
+                      << opOperand.getOperandNumber()
+                      << " in op: " << *opOperand.getOwner() << "\n");
+    return true;
+  }
+  LLVM_DEBUG(DBGS() << "no interfering read\n");
+  return false;
+}
+
+/// Return false if any of the following holds:
+/// 1. `opOperand` is produced by a constant op. For now this is assumed to be
+///    bufferized to a GlobalMemrefOp that cannot be written. Generalize in the
+///    future.
+/// 2. `opOperand` is a BlockArgument of a FuncOp that is not known to be
+///    bufferizable inplace.
+/// 3. `opOperand` has an interfering tensor read.
+/// Return true otherwise.
+bool isBufferizableInPlace(OpOperand &opOperand, const DominanceInfo &domInfo) {
+  // Constant tensors are deemed not bufferizable for now.
+  if (auto constantOp =
+          dyn_cast_or_null<ConstantOp>(opOperand.get().getDefiningOp()))
+    return !constantOp.getResult().getType().isa<RankedTensorType>();
+  if (auto bbArg = opOperand.get().dyn_cast<BlockArgument>()) {
+    // A function argument that may not be written to needs to be copied by
+    // its user(s).
+    // TODO: better propagate the fact that we want a single clone inside the
+    // function. At the moment every user that wants to write inplace will
+    // create its own alloc, irrespective of whether or not interfering reads
+    // occur.
+    if (isa<FuncOp>(bbArg.getOwner()->getParentOp()))
+      return getInPlace(bbArg) == InPlaceSpec::True;
+  }
+  return !hasInterferingTensorRead(opOperand, domInfo);
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific MemRefType support.
+//===----------------------------------------------------------------------===//
+
+/// Return a contiguous MemRefType (i.e. with canonical/empty layout map) with
+/// the same shape as `shapedType` and specified `layout` and `addressSpace`.
+static MemRefType getContiguousMemRefType(ShapedType shapedType,
+                                          ArrayRef<AffineMap> layout = {},
+                                          unsigned addressSpace = 0) {
+  if (RankedTensorType tensorType = shapedType.dyn_cast<RankedTensorType>())
+    return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+                           layout, addressSpace);
+  MemRefType memrefType = shapedType.cast<MemRefType>();
+  return MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
+                         layout, addressSpace);
+}
+
+/// Return a contiguous MemRefType (i.e. with canonical/empty layout map) with
+/// the same shape as `shapedType` and specified `layout` and `addressSpace` or
+/// an UnrankedMemRefType otherwise.
+static Type getContiguousOrUnrankedMemRefType(Type type,
+                                              ArrayRef<AffineMap> layout = {},
+                                              unsigned addressSpace = 0) {
+  if (type.isa<RankedTensorType, MemRefType>())
+    return getContiguousMemRefType(type.cast<ShapedType>(), layout,
+                                   addressSpace);
+  assert(layout.empty() && "expected empty layout with UnrankedMemRefType");
+  return UnrankedMemRefType::get(getElementTypeOrSelf(type), addressSpace);
+}
+
+/// Return a MemRefType to which the `tensorType` can be bufferized in a
+/// composable fashion. The layout must be the most dynamic possible and
+/// canonicalize away once bufferization is finished.
+static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
+                                       unsigned addressSpace = 0) {
+  // TODO: address space decisions to connect with the actual alloc.
+  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
+  SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
+                                      ShapedType::kDynamicStrideOrOffset);
+  AffineMap stridedLayout = makeStridedLinearLayoutMap(
+      dynamicStrides, dynamicOffset, tensorType.getContext());
+  return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+                         stridedLayout, addressSpace);
+}
+
+// Transfer all `dim` ops on `tensor` to `memref`.
+static void transferDimOpsToMemref(Value tensor, Value memref) {
+  for (OpOperand &opOperand : llvm::make_early_inc_range(tensor.getUses())) {
+    if (isa<memref::DimOp>(opOperand.getOwner())) {
+      opOperand.set(memref);
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific inPlace pattern matching support.
+//===----------------------------------------------------------------------===//
+
+/// First assign `op` if `sliceRef.back()` isa `T`, then check the condition.
+/// If anything fails just return failure. Otherwise update `sliceRef` by
+/// dropping `sliceRef.back()`, then return success().
+template <typename T>
+static LogicalResult
+matchAndDropBack(ArrayRef<Operation *> &sliceRef, T &op,
+                 llvm::function_ref<LogicalResult(T)> condition = nullptr) {
+  if (sliceRef.empty())
+    return failure();
+  op = dyn_cast<T>(sliceRef.back());
+  if (!op || (condition && failed(condition(op))))
+    return failure();
+  sliceRef = sliceRef.drop_back();
+  return success();
+}
+
+/// Detect whether `v` has a single user that is exactly `user`.
+static LogicalResult isInPlaceSingleUseOp(Value v, Operation *user) {
+  if (!v.hasOneUse() || *v.getUsers().begin() != user)
+    return failure();
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific scoped alloc/dealloc insertion support.
+//===----------------------------------------------------------------------===//
+
+/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
+/// `shapedValue.getDefiningOp` (or at the top of the block in case of a bbArg)
+/// and the DeallocOp is at the end of the block.
+static Value createNewAllocDeallocPairForShapedValue(
+    OpBuilder &b, Location loc, Value shapedValue,
+    SmallVector<Value, 4> dynOperands = {}) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+
+  // TODO: non-zero address space.
+  // TODO: layout information if relevant.
+  // Cannot allocate an unranked memref so just always go for the contiguous
+  // form.
+  MemRefType allocMemRefType =
+      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
+  assert(shapedValue.getType().isa<ShapedType>());
+  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
+  memRefType = memRefType ? memRefType : allocMemRefType;
+
+  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
+    b.setInsertionPointToStart(bbArg.getOwner());
+    loc = bbArg.getOwner()->getParentOp()->getLoc();
+  } else {
+    b.setInsertionPointAfter(shapedValue.getDefiningOp());
+    loc = shapedValue.getDefiningOp()->getLoc();
+  }
+
+  // If the dynOperands are not passed explicitly, compute them.
+  // This circumvents currently missing dim(init_tensor) canonicalizations.
+  if (dynOperands.empty()) {
+    for (auto dim : llvm::enumerate(memRefType.getShape()))
+      if (dim.value() == ShapedType::kDynamicSize)
+        dynOperands.push_back(
+            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
+  }
+  Value allocated =
+      b.create<memref::AllocOp>(loc, allocMemRefType, dynOperands);
+  Value casted = allocated;
+  if (memRefType != allocMemRefType)
+    casted = b.create<memref::CastOp>(loc, memRefType, allocated);
+  b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
+  b.create<memref::DeallocOp>(loc, allocated);
+  return casted;
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific inPlace analysis support.
+//===----------------------------------------------------------------------===//
+
+/// Detect the simple terminator pattern:
+/// ```
+///    candidate -> single-result with single-use linalg op -> term
+/// ```
+template <typename ContainerOp, typename TerminatorOp>
+static LogicalResult detectLinalgToTerminator(Operation *parentOp,
+                                              BlockArgument candidate,
+                                              ArrayRef<Operation *> &sliceRef) {
+  if (!parentOp || !isa<ContainerOp>(parentOp))
+    return failure();
+
+  ArrayRef<Operation *> tmpSliceRef = sliceRef;
+
+  TerminatorOp terminatorOp;
+  // Match returnOp and update tmpSliceRef.
+  if (failed(matchAndDropBack(tmpSliceRef, terminatorOp))) {
+    LLVM_DEBUG(DBGS() << "FAIL: linalg to term pattern -> slice must end "
+                         "with a known terminator\n");
+    return failure();
+  }
+
+  LinalgOp linalgOp;
+  // Match linalgOp with a single output tensor for now.
+  // TODO: support more than single result.
+  if (failed(matchAndDropBack(tmpSliceRef, linalgOp)) ||
+      linalgOp.getNumOutputTensors() != 1 ||
+      failed(isInPlaceSingleUseOp(linalgOp->getResult(0), terminatorOp))) {
+    LLVM_DEBUG(DBGS() << "FAIL: linalg to term pattern -> slice must end "
+                         "with single-result linalg op\n");
+    return failure();
+  }
+
+  // Commit what has been detected.
+  // TODO: support more than single result.
+  setInPlaceOpResult(linalgOp->getResult(0));
+  tmpSliceRef = sliceRef;
+  LLVM_DEBUG(DBGS() << "SUCCESS: linalg to term pattern\n");
+
+  return success();
+}
+
+/// The following uses internal knowledge of the position of tied operand /
+/// results.
+// TODO: TieOperandInterface.
+static void propagateInPlace(const SmallVector<OpOperand *> &initialWorklist,
+                             const DominanceInfo &domInfo) {
+  LLVM_DEBUG(DBGS() << "\n\n");
+  LLVM_DEBUG(DBGS() << "Start propagateInPlace from initial WL\n");
+  for (OpOperand *operand : initialWorklist)
+    LLVM_DEBUG(DBGS() << "WL item: " << operand->get() << " used by "
+                      << *operand->getOwner() << "\n");
+  SmallVector<OpOperand *> worklist(initialWorklist);
+  for (unsigned idx = 0; idx < worklist.size(); ++idx) {
+    // TODO: bail on subtensor/subtensor_insert and vector.transfer_read/write
+    // that should have been already captured in destructive update patterns?
+    OpOperand &operand = *worklist[idx];
+    LLVM_DEBUG(DBGS() << "WL item: " << *operand.getOwner() << "\n");
+    // If the owner turns out to be a CallOp without
+    // `kWriteableFuncBufferArgsAttrName` this will be a noop.
+    if (operand.get().getType().isa<TensorType>() &&
+        isBufferizableInPlace(operand, domInfo)) {
+      LLVM_DEBUG(DBGS() << "bufferizable inplace\n");
+      setInPlaceOpResult(getTiedOpResult(operand));
+    }
+    LLVM_DEBUG(DBGS() << "propagatedInPlace: " << *operand.getOwner() << "\n");
+    // The use can have interfering reads that prevent it from being written
+    // inPlace, but the values it produces are still themselves candidates for
+    // inPlace at their point of use.
+    for (Value v : operand.getOwner()->getResults()) {
+      LLVM_DEBUG(DBGS() << "propagate result: " << v << "\n");
+      for (auto &use : v.getUses()) {
+        LLVM_DEBUG(DBGS() << "add use to WL: " << use.get() << "\n");
+        worklist.push_back(&use);
+      }
+    }
+  }
+  LLVM_DEBUG(DBGS() << "\n\n");
+}
+
+static void propagateInPlace(OpOperand &opOperand,
+                             const DominanceInfo &domInfo) {
+  SmallVector<OpOperand *> worklist{&opOperand};
+  propagateInPlace(worklist, domInfo);
+}
+
+static void propagateInPlace(BlockArgument &bbArg,
+                             const DominanceInfo &domInfo) {
+  SmallVector<OpOperand *> worklist;
+  for (auto &use : bbArg.getUses())
+    worklist.push_back(&use);
+  propagateInPlace(worklist, domInfo);
+}
+
+/// Iterate over bbArgs of `parentOp` and determine if they are the root of a
+/// known destructive update chain. Such a destructive update is related to
+/// traditional loop nest + memory analysis but provides a simpler SSA use-def
+/// chain-based abstraction.
+static void destructiveUpdateAnalysis(Block *block,
+                                      const DominanceInfo &domInfo) {
+  Operation *parentOp = block->getParentOp();
+  for (BlockArgument candidate : block->getArguments()) {
+    LLVM_DEBUG(llvm::dbgs() << "\n\n");
+    LLVM_DEBUG(DBGS() << "Destructive update analysis on candidate: "
+                      << candidate << "\nof:\n"
+                      << *parentOp << "\n");
+
+    if (!candidate.getType().isa<ShapedType>()) {
+      LLVM_DEBUG(DBGS() << "Not a tensor\n");
+      continue;
+    }
+
+    // FuncOp arguments must be inplaceable, otherwise they cannot be the root
+    // of a destructive update chain.
+    if (isa<FuncOp>(parentOp) && getInPlace(candidate) != InPlaceSpec::True)
+      continue;
+
+    llvm::SetVector<Operation *> slice;
+    getForwardSlice(candidate, &slice, [&](Operation *op) {
+      // Skip any extra nesting between parentOp and op.
+      return op == parentOp || op->getBlock()->getParentOp() == parentOp;
+    });
+
+    LLVM_DEBUG(DBGS() << "Slice:\n");
+    for (auto *op : slice)
+      LLVM_DEBUG(DBGS() << *op << "\n");
+
+    ArrayRef<Operation *> sliceRef = slice.getArrayRef();
+    bool failedDetectingDestructiveUpdate =
+        // func / return inplace patterns.
+        failed(detectLinalgToTerminator<FuncOp, ReturnOp>(parentOp, candidate,
+                                                          sliceRef));
+    if (failedDetectingDestructiveUpdate) {
+      LLVM_DEBUG(DBGS() << "Failed to detect a destructive update pattern\n");
+      continue;
+    }
+
+    propagateInPlace(candidate, domInfo);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization as simple BlockAndValueMapping rewrites.
+//===----------------------------------------------------------------------===//
+
+/// Helper function for LinalgOp bufferization.
+/// Operate on mixed tensor + buffer Linalg ops for progressive bufferization.
+/// Allocate the output buffers for the remaining tensor output operands of
+/// the Linalg op. If the tensor is an "init" tensor (i.e. its value is
+/// actually used in the payload region), we additionally copy the original
+/// value into the newly allocated buffer.
+static LogicalResult
+allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
+                          SmallVectorImpl<Value> &resultBuffers,
+                          BlockAndValueMapping &bvm) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+
+  // Lazily compute loopRanges.
+  SmallVector<Range, 4> loopRanges;
+
+  // Linalg invariant: output tensors and results match 1-1.
+  assert(op.getNumOutputTensors() == op->getNumResults());
+  for (auto &opOperand : op.getOutputOpOperands()) {
+    Value output = opOperand.get();
+    if (output.getType().isa<MemRefType>()) {
+      resultBuffers.push_back(output);
+      continue;
+    }
+
+    // If output tensor is marked inPlace, just use the buffer.
+    // The following uses internal knowledge of the position of tied operand /
+    // results. A proper TieOperandInterface would be much better.
+    if (getInPlace(getTiedOpResult(opOperand)) == InPlaceSpec::True) {
+      resultBuffers.push_back(lookup(bvm, output));
+      continue;
+    }
+
+    Value dimTensor = bvm.lookupOrDefault(output);
+    Value alloc = createNewAllocDeallocPairForShapedValue(b, loc, dimTensor);
+    b.setInsertionPointAfter(alloc.getDefiningOp());
+    resultBuffers.push_back(alloc);
+
+    // Additionally, if the output buffer is used, clone its value for now.
+    if (op.payloadUsesValueFromOpOperand(&opOperand))
+      b.create<CopyOp>(loc, lookup(bvm, output), alloc);
+  }
+  map(bvm, op->getResults(), resultBuffers);
+  for (auto it : llvm::zip(op->getResults(), resultBuffers)) {
+    transferDimOpsToMemref(std::get<0>(it), std::get<1>(it));
+  }
+  return success();
+}
+
+static void finalizeBufferAllocation(OpBuilder &b, LinalgOp op,
+                                     ValueRange inputs, ValueRange outputs,
+                                     BlockAndValueMapping &bvm) {
+  SmallVector<Value, 8> newOperands = inputs;
+  newOperands.append(outputs.begin(), outputs.end());
+  auto otherOperands = op.getAssumedNonShapedOperands();
+  newOperands.append(otherOperands.begin(), otherOperands.end());
+  Location loc = op.getLoc();
+  op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands);
+
+  // Replace the results of the old op with the new output buffers.
+  map(bvm, op.getOperation()->getResults(), outputs);
+  for (auto it : llvm::zip(op.getOperation()->getResults(), outputs)) {
+    transferDimOpsToMemref(std::get<0>(it), std::get<1>(it));
+  }
+
+  if (!op.hasTensorSemantics())
+    op->erase();
+}
+
+/// Generic conversion for any LinalgOp.
+/// Operate on mixed tensor + buffer Linalg ops for progressive bufferization.
+static LogicalResult convertAnyLinalgOp(OpBuilder &b, LinalgOp op,
+                                        BlockAndValueMapping &bvm) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+
+  if (op.hasBufferSemantics())
+    return failure();
+
+  LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
+
+  b.setInsertionPoint(op);
+  Location loc = op.getLoc();
+  SmallVector<Value, 2> newInputBuffers;
+  newInputBuffers.reserve(op.getNumInputs());
+  for (Value v : op.getInputs()) {
+    newInputBuffers.push_back(lookup(bvm, v));
+  }
+  SmallVector<Value, 2> newOutputBuffers;
+  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm)))
+    assert(false);
+
+  // Delegate to the linalg generic pattern.
+  if (auto genericOp = dyn_cast<GenericOp>(op.getOperation())) {
+    finalizeBufferAllocation(b, genericOp, newInputBuffers, newOutputBuffers,
+                             bvm);
+    return success();
+  }
+
+  finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm);
+
+  return success();
+}
+
+/// FuncOp always creates TensorToMemRef ops.
+static LogicalResult convertFuncOp(OpBuilder &b, FuncOp funcOp,
+                                   BlockAndValueMapping &bvm) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPointToStart(&funcOp.body().front());
+  for (auto bbArg : funcOp.getArguments()) {
+    auto tensorType = bbArg.getType().dyn_cast<TensorType>();
+    auto rankedTensorType = bbArg.getType().dyn_cast<RankedTensorType>();
+    if (!tensorType)
+      continue;
+    // Cast the tensor to the most dynamic buffer possible. Further
+    // canonicalizations will clean up.
+    Type memRefType = rankedTensorType
+                          ? getDynamicMemRefType(rankedTensorType)
+                          : getContiguousOrUnrankedMemRefType(tensorType);
+    Value tensorToMemref =
+        b.create<memref::BufferCastOp>(funcOp.getLoc(), memRefType, bbArg);
+    map(bvm, bbArg, tensorToMemref);
+  }
+  return success();
+}
+
+/// ReturnOp always creates memref::TensorLoadOp.
+static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp,
+                                     BlockAndValueMapping &bvm) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(returnOp);
+
+  FuncOp funcOp = cast<FuncOp>(returnOp->getParentOp());
+  assert(funcOp && "only support FuncOp parent for ReturnOp");
+  for (OpOperand &operand : returnOp->getOpOperands()) {
+    auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+    if (!tensorType)
+      continue;
+    operand.set(b.create<memref::TensorLoadOp>(returnOp.getLoc(),
+                                               lookup(bvm, operand.get())));
+  }
+  return success();
+}
+
+static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
+                                           const DominanceInfo &domInfo) {
+  if (!funcOp || funcOp->getNumRegions() == 0 || funcOp.body().empty())
+    return;
+
+  // Start propagating from FuncOp bbArgs.
+  destructiveUpdateAnalysis(&funcOp.body().front(), domInfo);
+}
+
+static LogicalResult bufferizeFuncOpInternals(
+    FuncOp funcOp, BlockAndValueMapping &bvm, GlobalCreator &globals,
+    const DenseMap<FuncOp, SmallVector<int64_t>> &tiedResultsMap) {
+  OpBuilder b(funcOp->getContext());
+  /// Start by converting `funcOp` arguments.
+  SmallVector<Operation *> convertedCallOps;
+  LogicalResult status = convertFuncOp(b, funcOp, bvm);
+  if (failed(status))
+    return status;
+  funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
+    status =
+        llvm::TypeSwitch<Operation *, LogicalResult>(op)
+            // Skip BufferCast and TensorLoad ops.
+            .Case<memref::BufferCastOp, memref::TensorLoadOp>(
+                [&](auto) { return success(); })
+            .Case([&](LinalgOp op) { return convertAnyLinalgOp(b, op, bvm); })
+            .Case([&](ReturnOp op) { return convertReturnOp(b, op, bvm); })
+            .Default([&](Operation *op) {
+              auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
+              if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
+                  llvm::any_of(op->getResultTypes(), isaTensor))
+                return failure();
+              return success();
+            });
+    if (failed(status)) {
+      op->emitError("Failed bufferization");
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  if (failed(status))
+    return status;
+
+  // Delete all the converted call operations.
+  for (Operation *op : convertedCallOps)
+    op->erase();
+
+  return status;
+}
+
+namespace {
+struct ComprehensiveFuncBufferize
+    : public ComprehensiveFuncBufferizeBase<ComprehensiveFuncBufferize> {
+  void runOnFunction();
+};
+} // end namespace
+
+void ComprehensiveFuncBufferize::runOnFunction() {
+  auto funcOp = getFunction();
+  if (!funcOp || funcOp->getNumRegions() == 0 || funcOp.body().empty())
+    return;
+
+  DominanceInfo domInfo(funcOp);
+  GlobalCreator globals(funcOp->getParentOfType<ModuleOp>());
+  BlockAndValueMapping bvm;
+  DenseMap<FuncOp, SmallVector<int64_t>> tiedResultsMap;
+  inPlaceAnalysisFuncOpInternals(funcOp, domInfo);
+
+  LLVM_DEBUG(DBGS() << "Begin BufferizeFuncOpInternals:\n" << funcOp << "\n");
+  auto guard = llvm::make_scope_exit([&] {
+    funcOp.walk(
+        [&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); });
+    LLVM_DEBUG(DBGS() << "End BufferizeFuncOpInternals:\n" << funcOp << "\n");
+  });
+  if (failed(bufferizeFuncOpInternals(funcOp, bvm, globals, tiedResultsMap)))
+    signalPassFailure();
+}
+
+std::unique_ptr<Pass> mlir::createComprehensiveFuncBufferizePass() {
+  return std::make_unique<ComprehensiveFuncBufferize>();
+}
diff --git a/mlir/test/Transforms/comprehensive-func-bufferize.mlir b/mlir/test/Transforms/comprehensive-func-bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/comprehensive-func-bufferize.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -comprehensive-func-bufferize -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$map_2d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK-LABEL: func @fill_inplace(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor<?xf32> {tensor.inplaceable = true})
+func @fill_inplace(%A : tensor<?xf32> {tensor.inplaceable = true}) -> tensor<?xf32> {
+  // CHECK: %[[I:.*]] = memref.buffer_cast %[[A]] : memref<?xf32, #[[$map_2d_dyn]]>
+
+  // CHECK: %[[F0:.*]] = constant 0.000000e+00 : f32
+  %f0 = constant 0.0 : f32
+
+  // CHECK: linalg.fill(%[[I]], %[[F0]]) : memref<?xf32, #[[$map_2d_dyn]]>, f32
+  %r = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+
+  // CHECK: %[[R:.*]] = memref.tensor_load %[[I]] : memref<?xf32, #[[$map_2d_dyn]]>
+  // CHECK: return %[[R]] : tensor<?xf32>
+  return %r: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$map_2d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK-LABEL: func @fill_out_of_place(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor<?xf32>)
+func @fill_out_of_place(%A : tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: %[[I:.*]] = memref.buffer_cast %[[A]] : memref<?xf32, #[[$map_2d_dyn]]>
+
+  // CHECK: %[[D0:.*]] = memref.dim %[[I]], {{.*}} : memref<?xf32, #[[$map_2d_dyn]]>
+  // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[D0]]) : memref<?xf32>
+  // CHECK: %[[I2:.*]] = memref.cast %[[ALLOC]] : memref<?xf32> to memref<?xf32, #map>
+
+  // CHECK: %[[F0:.*]] = constant 0.000000e+00 : f32
+  %f0 = constant 0.0 : f32
+
+  // CHECK: linalg.fill(%[[I2]], %[[F0]]) : memref<?xf32, #[[$map_2d_dyn]]>, f32
+  %r = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+
+  // CHECK: dealloc %[[ALLOC]] : memref<?xf32>
+  // CHECK: %[[R:.*]] = memref.tensor_load %[[I2]] : memref<?xf32, #[[$map_2d_dyn]]>
+  // CHECK: return %[[R]] : tensor<?xf32>
+  return %r: tensor<?xf32>
+}
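
Note (not part of the patch): besides the `mlir-opt -comprehensive-func-bufferize` invocation exercised by the test above, the pass can also be driven programmatically. The following is a minimal sketch only; `createComprehensiveFuncBufferizePass()` comes from the patch, while the driver function name `runComprehensiveBufferize` and the trailing canonicalizer are illustrative assumptions.

```cpp
// Sketch: run the new pass on every function of a module, then canonicalize
// to fold away the memref.cast ops introduced by the "most dynamic strided
// layout" convention described in the file header comment. Assumes `module`
// was parsed into a context with the required dialects loaded.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

static mlir::LogicalResult runComprehensiveBufferize(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // ComprehensiveFuncBufferize is a FunctionPass, so nest it on FuncOp.
  pm.addNestedPass<mlir::FuncOp>(mlir::createComprehensiveFuncBufferizePass());
  pm.addPass(mlir::createCanonicalizerPass());
  return pm.run(module);
}
```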