diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -37,7 +37,17 @@ let dependentDialects = [ "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect" ]; + let hasOperationAttrVerify = 1; let extraClassDeclaration = [{ + /// Attribute name used to to memoize indexing maps for named ops. + constexpr const static ::llvm::StringLiteral + kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps"; + + /// Attribute name used to mark region arguments that can be bufferized + /// in-place during linalg comprehensive bufferization. + constexpr const static ::llvm::StringLiteral + kInplaceableAttrName = "linalg.inplaceable"; + using RegionBuilderFunType = llvm::function_ref; RegionBuilderFunType getRegionBuilder(StringRef name) { return namedStructuredOpRegionBuilders.lookup(name); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -23,6 +23,7 @@ #include "mlir/IR/Types.h" #include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/InplaceInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -17,6 +17,7 @@ include "mlir/Dialect/Linalg/IR/LinalgBase.td" include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" +include "mlir/Interfaces/InplaceInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -26,7 +27,9 @@ // depending on the specific Linalg op. class LinalgStructuredBase_Op props> : Op { + LinalgStructuredInterface, + InferShapedTypeOpInterface, + InplaceOpInterface])> { code structuredOpsBaseDecls = [{ // Return the number of induction variables in the basic block. This should // always be 0 for index-free linalg ops. For IndexedGeneric, this must be @@ -48,6 +51,31 @@ return cast(getOperation()).reifyReturnTypeShapesPerResultDim(b, reifiedReturnShapes); } + + // Return the OpResult that is tied to an operand. + OpResult getTiedOpResult(OpOperand &opOperand) { + if (!opOperand.get().getType().isa()) + return OpResult(); + // For now assume inputs are never inplaceable. + // TODO: refine this. + if (opOperand.getOperandNumber() < getNumInputs()) + return OpResult(); + // For now assume if the operand appears twice, it is not inplaceable. + // TODO: refine this. 
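+      // E.g. `linalg.generic ins(%t : tensor<4xf32>) outs(%t : tensor<4xf32>)`
+      // feeds the same tensor to an input and the output; bufferizing the
+      // result onto %t would overwrite a value that is still read, so no
+      // result is tied in that case.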
+ for (auto &opOperand2 : getOperation()->getOpOperands()) { + if (opOperand.getOperandNumber() == opOperand2.getOperandNumber()) + continue; + if (opOperand.get() == opOperand2.get()) + return OpResult(); + } + int64_t outputOperandIndex = + opOperand.getOperandNumber() - getNumInputs(); + int64_t numOutputBuffers = 0; + for (unsigned idx = 0; idx < outputOperandIndex; ++idx) + if (!getOutputShapedType(idx).isa()) + ++numOutputBuffers; + return getOperation()->getResult(outputOperandIndex - numOutputBuffers); + } }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -53,6 +53,12 @@ /// Placeholder for now, this is NYI. std::unique_ptr> createConvertLinalgToAffineLoopsPass(); +/// Create a pass that bufferizes the body of a FuncOp and tries to reuse the +/// buffers for those arguments that: +/// a) have been annotated 'inplaceable' and +/// b) whose buffer uses would be free of memory hazards. +std::unique_ptr createLinalgComprehensiveFuncBufferizePass(); + /// Create a pass to convert Linalg operations which work on tensors to use /// buffers instead. std::unique_ptr> createLinalgBufferizePass(); diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -22,6 +22,21 @@ let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"]; } +def LinalgComprehensiveFuncBufferize : + FunctionPass<"linalg-comprehensive-func-bufferize"> { + let summary = "Bufferize (tensor into memref) the body of a FuncOp and try " + "to reuse the buffers for those arguments that " + "a) have been annotated 'inplaceable' and " + "b) whose buffer uses would be free of memory hazards"; + let description = [{ + This pass implements a cross-dialect bufferization approach and performs an + analysis to determine which op operands and results may be bufferized in the + same buffers. The analysis is performed on SSA use-def chains starting from + function operands that are annotated with the 'inplaceable' attribute + }]; + let constructor = "mlir::createLinalgComprehensiveFuncBufferizePass()"; +} + def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> { let summary = "Remove unit-extent dimension in Linalg ops on tensors"; let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()"; diff --git a/mlir/include/mlir/Interfaces/InplaceInterface.h b/mlir/include/mlir/Interfaces/InplaceInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InplaceInterface.h @@ -0,0 +1,24 @@ +//===- InplaceInterface.h - Inplace interface -----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements interfaces that specify which tensor results may fold +// onto operands to enable inplace bufferization. 
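+// For example, a linalg op on tensors ties each output tensor operand to the
+// matching OpResult, signalling that the result may reuse that operand's
+// buffer when the op is bufferized.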
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INPLACEINTERFACE_H_ +#define MLIR_INTERFACES_INPLACEINTERFACE_H_ + +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" + +/// Include the generated interface declarations. +#include "mlir/Interfaces/InplaceInterface.h.inc" + +#endif // MLIR_INTERFACES_INPLACEINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/InplaceInterface.td b/mlir/include/mlir/Interfaces/InplaceInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InplaceInterface.td @@ -0,0 +1,43 @@ +//===- InplaceInterfaces.td - Inplace interface ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface that specifies which tensor results may fold onto +// tensor operands to enable inplace bufferization. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INPLACEINTERFACE +#define MLIR_INTERFACES_INPLACEINTERFACE + +include "mlir/IR/OpBase.td" + +def InplaceOpInterface : OpInterface<"InplaceOpInterface"> { + let description = [{ + Encodes properties of an operation that specifies which tensor results may + fold onto tensor operands to enable inplace bufferization. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the OpResult that is tied to an operand. The OpResult has the + same type as `opOperand`. Used to specify which tensor result may fold + onto `opOperand` during bufferization. + Used to describe ops whose buffer variant have destructive update + semantics (i.e. inplace update). + Return null if op lacks destructive update semantics for `opOperand`. + }], + /*retTy=*/"OpResult", + /*methodName=*/"getTiedOpResult", + /*args=*/(ins "OpOperand&":$opOperand) + >, + ]; +} + +#endif // MLIR_INTERFACES_INPLACEINTERFACE diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/FunctionSupport.h" #include "mlir/Parser.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/InliningUtils.h" @@ -57,6 +58,14 @@ // LinalgDialect //===----------------------------------------------------------------------===// +/// Attribute name used to to memoize indexing maps for named ops. +constexpr const ::llvm::StringLiteral + LinalgDialect::kMemoizedIndexingMapsAttrName; + +/// Attribute name used to mark region arguments that can be bufferized +/// in-place during linalg comprehensive bufferization. +constexpr const ::llvm::StringLiteral LinalgDialect::kInplaceableAttrName; + /// Trait to check if T provides a `regionBuilder` method. 
template using has_region_builder = decltype(T::regionBuilder); @@ -131,3 +140,21 @@ DialectAsmPrinter &os) const { print(type.cast(), os); } + +LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + if (attr.first == LinalgDialect::kInplaceableAttrName) { + if (!attr.second.isa()) { + return op->emitError() << "'" << LinalgDialect::kInplaceableAttrName + << "' is expected to be a boolean attribute"; + } + if (!op->hasTrait()) + return op->emitError() << "expected " << attr.first + << " to be used on function-like operations"; + return success(); + } + if (attr.first == LinalgDialect::kMemoizedIndexingMapsAttrName) + return success(); + return op->emitError() << "attribute '" << attr.first + << "' not supported by the linalg dialect"; +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferization.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferization.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferization.cpp @@ -0,0 +1,779 @@ +//===- ComprehensiveBufferization.cpp - Single pass bufferization ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Perform inplace bufferization within function boundaries. +// This is a specialized pass that supports inplace analysis for a fixed subset +// of ops that have well-defined inplace semantics. +// This pass caters to high-performance codegen where buffer reuse is deemed +// necessary: the pass should fail if the bufferized form of the function needs +// to return any buffer. +// Generic control-flow and branching are unsupported. +// Composability with extensible set of ops is not a first-class concern. +// +// Bufferization occurs by: +// a. performing an inPlace analysis `inPlaceAnalysisFuncOpInternals` +// which marks each operation within the function with the +// `kInPlaceResultsAttrName` attribute. +// b. traversing each operation in the function and rewriting it in +// buffer form and keeping a BlockAndValueMapping mapping of the +// rewrites. New allocations are introduced during this step. +// TODO: Allocation + depending op hoisting to outermost enclosing +// sequential scope. +// c. at the end of this bufferization, 2 cases may occur: +// * inplaceable function arguments may be reused in place after the +// function itself has been bufferized. This is encoded by IR resembling: +// +// ``` +// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> +// func @foo(%A: tensor {linalg.inplaceable = true}) -> tensor { +// %0 = memref.buffer_cast %A : memref +// // ... uses of %0 +// %res = memref.tensor_load %0 : memref +// return %res : tensor +// } +// ``` +// +// this is the cue for the bufferization of the function foo (and calls to +// it) may bufferize to `func @foo(%A: memref)`. +// To fully achieve bufferization, an additional analysis is needed to +// determine whether function argument/operand pairs bufferize to a single +// inplace buffer argument (i.e. functions may return tensors in arbitrary +// order that may not match argument numbers). +// * results that don't map to an inplaceable function argument must be +// allocated. 
Since memref semantics wrt ownership of the underlying +// memory region are not well-defined, comprehensive bufferization chooses +// to perform allocations in a scoped fashion: returning memrefs is always +// considered illegal. Such scenarios are encoded by IR resembling: +// +// ``` +// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> +// func @foo(%A: tensor {linalg.inplaceable = true}) -> tensor { +// %0 = memref.buffer_cast %A : memref +// %1 = memref.dim %0, %c0 : memref +// %2 = memref.alloc(%1) : memref +// %3 = memref.cast %2 : memref to memref +// // ... uses of %3 +// memref.dealloc %2 : memref +// %res = memref.tensor_load %3 : memref +// return %res : tensor +// } +// ``` +// +// this is the cue for the bufferization of the function foo (and calls to +// it) that it must bufferize to +// `func @foo(%A: memref, +// %B: memref)` (i.e. make a cloned +// allocation of the result tensor) +// To fully achieve bufferization, the alloc/dealloc pair must be lifted +// out of the function at each call site. +// +// Lastly, note that layout map chosen to bufferize is the most dynamic +// canonical strided layout of the proper rank. This ensures compatibility with +// expected layouts after transformations. Combinations of memref.cast + +// canonicalization are responsible for clean ups. + +#include "PassDetail.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/InplaceInterface.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/BufferUtils.h" + +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/TypeSwitch.h" + +#define DEBUG_TYPE "comprehensive-func-bufferize" + +using namespace mlir; +using namespace linalg; +using namespace tensor; + +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") + +//===----------------------------------------------------------------------===// +// Bufferization-specific attribute manipulation. +//===----------------------------------------------------------------------===// + +/// Attribute marker to specify op results that can be bufferized inPlace. +constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_results_attr__"; + +// TODO: proper enum. +enum class InPlaceSpec { + False, + True, + None, +}; + +static StringRef stringify(InPlaceSpec val) { + switch (val) { + case InPlaceSpec::False: + return "false"; + case InPlaceSpec::True: + return "true"; + case InPlaceSpec::None: + return "none"; + } + return ""; +} + +static Optional symbolize(StringRef str) { + return StringSwitch>(str) + .Case("false", InPlaceSpec::False) + .Case("true", InPlaceSpec::True) + .Case("none", InPlaceSpec::None) + .Default(None); +} + +/// Mark whether OpResult can actually be bufferized inplace. If `inPlace` is +/// `InPlaceSpec::True`, the use-def chain analysis has guaranteed that no +/// subsequent write would occur to the bufferized tensor value (i.e. the result +/// can be bufferized inPlace). +static void setInPlaceOpResult(OpResult opResult, + InPlaceSpec inPlace = InPlaceSpec::True) { + if (!opResult) + return; + + Operation *op = opResult.getOwner(); + auto attr = + op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null(); + SmallVector inPlaceVector = + attr ? 
SmallVector( + llvm::to_vector<4>(attr.getAsValueRange())) + : SmallVector(op->getNumResults(), + stringify(InPlaceSpec::None)); + LLVM_DEBUG(DBGS() << "Set inPlace=" << stringify(inPlace) << ": " << *op + << " @idx=" << opResult.getResultNumber() << "\n"); + inPlaceVector[opResult.getResultNumber()] = stringify(inPlace); + op->setAttr(kInPlaceResultsAttrName, + OpBuilder(op).getStrArrayAttr(inPlaceVector)); +} + +/// Get the InPlaceSpec attribute entry `kInPlaceResultsAttrName` for +/// `opResult`. If the result is `InPlaceSpec::True`, the use-def chain analysis +/// has guaranteed that no subsequent read of the tensor value occurs and the +/// result can be buferized inPlace. +/// If no InPlaceSpec attribute has been set for `opResult`, return +/// InPlaceSpec::None. +static InPlaceSpec getInPlace(OpResult opResult) { + if (!opResult) + return InPlaceSpec::None; + + Operation *op = opResult.getOwner(); + auto attr = + op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null(); + if (!attr) + return InPlaceSpec::None; + + // Must return a proper value. + return *symbolize(*(attr.getAsValueRange().begin() + + opResult.getResultNumber())); +} + +/// Get inPlace information for `bbArg`. +/// If it does not come from a function, return InPlaceSpec::False. +static InPlaceSpec getInPlace(BlockArgument bbArg) { + auto funcOp = dyn_cast(bbArg.getOwner()->getParentOp()); + if (!funcOp) + return InPlaceSpec::False; + auto attr = funcOp.getArgAttrOfType( + bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName); + if (!attr) + return InPlaceSpec::None; + return attr.getValue() ? InPlaceSpec::True : InPlaceSpec::False; +} + +//===----------------------------------------------------------------------===// +// Bufferization-specific BlockAndValueMapping support with debugging. +//===----------------------------------------------------------------------===// + +/// Wrapper for better debugging. +static void map(BlockAndValueMapping &bvm, ValueRange keys, ValueRange values) { + assert(!keys.empty() && "Unexpected empty keys"); + LLVM_DEBUG(DBGS() << "Map: " << keys.front() << " to " << values.front() + << "\n"); + return bvm.map(keys, values); +} + +/// Wrapper for better debugging. +static void map(BlockAndValueMapping &bvm, Value key, Value value) { + LLVM_DEBUG(DBGS() << "Map: " << key << " to " << value << "\n"); + return bvm.map(key, value); +} + +/// Wrapper for better debugging. +static Value lookup(BlockAndValueMapping &bvm, Value key) { + // TODO: if key comes from bbArg, forward. + assert(key.getType().isa()); + if (!bvm.lookupOrNull(key)) { + if (auto bbArg = key.dyn_cast()) { + if (isa(key.getParentBlock()->getParentOp())) + key.getParentBlock()->getParentOp()->dump(); + else + key.getParentBlock()->getParentOp()->getParentOfType()->dump(); + bbArg.getOwner()->getParentOp()->dump(); + } else { + key.getDefiningOp()->getParentOfType()->dump(); + } + llvm::errs() << "NO VALUE FOR KEY: " << key << "\n"; + abort(); + } + return bvm.lookup(key); +} + +//===----------------------------------------------------------------------===// +// Bufferization-specific support. +//===----------------------------------------------------------------------===// + +/// Determine whether any subsequent read of the tensor `opOperand` may occur. +/// For now, this assumes any use is a read. If any use of the tensor does not +/// properly dominate `opOperand.getOwner()`, then the tensor cannot be +/// bufferized inPlace. +// TODO: For now, this assumes any use is a read. Refine this. 
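+// Illustration (not a test from this patch): given
+//   %r0 = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+//   %r1 = linalg.fill(%A, %f1) : tensor<?xf32>, f32 -> tensor<?xf32>
+// the second use of %A neither properly dominates nor equals the first fill's
+// operand and is conservatively treated as a read, so the first fill's use of
+// %A is reported as having an interfering read.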
+bool hasInterferingTensorRead(OpOperand &opOperand, + const DominanceInfo &domInfo) { + if (!opOperand.get().getType().isa()) + return false; + for (auto &use : opOperand.get().getUses()) { + Operation *user = use.getOwner(); + + // If properly dominate, there is a clear sequence point and we can dismiss + // read. + if (domInfo.properlyDominates(user, opOperand.getOwner())) + continue; + // Otherwise, we need to analyze self-dependencies, for now just let it go. + // TODO: proper self-dependence analysis. + if (domInfo.dominates(user, opOperand.getOwner())) + continue; + if (user == opOperand.getOwner() && + use.getOperandNumber() == opOperand.getOperandNumber()) + continue; + LLVM_DEBUG(DBGS() << "found interfering read operand #" + << opOperand.getOperandNumber() + << " in op: " << *opOperand.getOwner() << "\n"); + return true; + } + LLVM_DEBUG(DBGS() << "no interfering read\n"); + return false; +} + +/// Return false if either: +/// 1. `opOperand` is produced by a constant op. For now this is assumed to be +/// bufferized to a GlobalMemrefOp that cannot be written. Generalize in the +/// future. +/// 2.`opOperand` is a BlockArgument of a FuncOp that is not known to be +/// bufferizable inplace. +/// 3.`opOperand` has an interfering tensor read. +/// Return true otherwise. +bool isBufferizableInPlace(OpOperand &opOperand, const DominanceInfo &domInfo) { + // Constant tensors are deemed not bufferizable for now. + if (auto constantOp = + dyn_cast_or_null(opOperand.get().getDefiningOp())) + return !constantOp.getResult().getType().isa(); + if (auto bbArg = opOperand.get().dyn_cast()) { + // Function argument that may not be written-to needs to be copied by + // user(s). + // TODO: better propagate the fact that we want a single clone inside the + // function. Atm every user that wants to write inplace will create its own + // alloc, irrespective of whether or not interfering reads occur. + if (isa(bbArg.getOwner()->getParentOp())) + return getInPlace(bbArg) == InPlaceSpec::True; + } + return !hasInterferingTensorRead(opOperand, domInfo); +} + +//===----------------------------------------------------------------------===// +// Bufferization-specific MemRefType support. +//===----------------------------------------------------------------------===// + +/// Return a contiguous MemRefType (i.e. with canonical/empty layout map) with +/// the same shape as `shapedType` and specified `layout` and `addressSpace`. +static MemRefType getContiguousMemRefType(ShapedType shapedType, + ArrayRef layout = {}, + unsigned addressSpace = 0) { + if (RankedTensorType tensorType = shapedType.dyn_cast()) + return MemRefType::get(tensorType.getShape(), tensorType.getElementType(), + layout, addressSpace); + MemRefType memrefType = shapedType.cast(); + return MemRefType::get(memrefType.getShape(), memrefType.getElementType(), + layout, addressSpace); +} + +/// Return a contiguous MemRefType (i.e. with canonical/empty layout map) with +/// the same shape as `shapedType` and specified `layout` and `addressSpace` or +/// an UnrankedMemRefType otherwise. +static Type getContiguousOrUnrankedMemRefType(Type type, + ArrayRef layout = {}, + unsigned addressSpace = 0) { + if (type.isa()) + return getContiguousMemRefType(type.cast(), layout, + addressSpace); + assert(layout.empty() && "expected empty layout with UnrankedMemRefType"); + return UnrankedMemRefType::get(getElementTypeOrSelf(type), addressSpace); +} + +/// Return a MemRefType to which the `tensorType` can be bufferized in a +/// composable fashion. 
The layout must be the most dynamic possible and +/// canonicalize away once bufferization is finished. +static MemRefType getDynamicMemRefType(RankedTensorType tensorType, + unsigned addressSpace = 0) { + // TODO: address space decisions to connect with the actual alloc. + int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; + SmallVector dynamicStrides(tensorType.getRank(), + ShapedType::kDynamicStrideOrOffset); + AffineMap stridedLayout = makeStridedLinearLayoutMap( + dynamicStrides, dynamicOffset, tensorType.getContext()); + return MemRefType::get(tensorType.getShape(), tensorType.getElementType(), + stridedLayout, addressSpace); +} + +//===----------------------------------------------------------------------===// +// Bufferization-specific inPlace pattern matching support. +//===----------------------------------------------------------------------===// + +/// First assign `op` if `slice.back()` isa `T`, then check condition. +/// If anything fails just return failure. Otherwise update `sliceRef` by +/// dropping `sliceRef.back()`, then return success(). +template +static LogicalResult +matchAndDropBack(ArrayRef &sliceRef, T &op, + llvm::function_ref condition = nullptr) { + if (sliceRef.empty()) + return failure(); + op = dyn_cast(sliceRef.back()); + if (!op || (condition && failed(condition(op)))) + return failure(); + sliceRef = sliceRef.drop_back(); + return success(); +} + +//===----------------------------------------------------------------------===// +// Bufferization-specific scoped alloc/dealloc insertion support. +//===----------------------------------------------------------------------===// + +/// Create an Allocop/DeAllocOp pair, where the AllocOp is after +/// `shapedValue.getDefiningOp` (or at the top of the block in case of a bbArg) +/// and the DeallocOp is at the end of the block. +static Value createNewAllocDeallocPairForShapedValue( + OpBuilder &b, Location loc, Value shapedValue, + SmallVector dynOperands = {}) { + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + + // TODO: non-zero address space. + // TODO: layout information if relevant. + // Cannot allocate an unranked memref so just always go for the contiguous + // form. + MemRefType allocMemRefType = + getContiguousMemRefType(shapedValue.getType().cast()); + assert(shapedValue.getType().isa()); + MemRefType memRefType = shapedValue.getType().dyn_cast(); + memRefType = memRefType ? memRefType : allocMemRefType; + + if (auto bbArg = shapedValue.dyn_cast()) { + b.setInsertionPointToStart(bbArg.getOwner()); + loc = bbArg.getOwner()->getParentOp()->getLoc(); + } else { + b.setInsertionPointAfter(shapedValue.getDefiningOp()); + loc = shapedValue.getDefiningOp()->getLoc(); + } + + // If the dynOperands are not passed explicity, copmpute them. + // This circumvents currently missing dim(init_tensor) canonicalizations. + // TODO: dim(init_tensor) canonicalization. 
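+  // E.g. for a memref<4x?xf32> only dimension 1 is dynamic, so a single
+  // memref.dim is created to size the allocation.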
+ if (dynOperands.empty()) { + for (auto dim : llvm::enumerate(memRefType.getShape())) + if (dim.value() == ShapedType::kDynamicSize) + dynOperands.push_back( + b.create(loc, shapedValue, dim.index())); + } + + Value allocated = + b.create(loc, allocMemRefType, dynOperands); + Value casted = allocated; + if (memRefType != allocMemRefType) + casted = b.create(loc, memRefType, allocated); + b.setInsertionPoint(allocated.getParentBlock()->getTerminator()); + b.create(loc, allocated); + return casted; +} + +//===----------------------------------------------------------------------===// +// Bufferization-specific inPlace analysis support. +//===----------------------------------------------------------------------===// + +/// Detect the simple terminator pattern: +/// ``` +/// candidate -> single-result with single-use linalg op -> term +/// ``` +template +static LogicalResult +detectInplaceOpToTerminator(Operation *parentOp, BlockArgument candidate, + ArrayRef &sliceRef) { + assert(parentOp && "Unexpected null parent op"); + if (!isa(parentOp)) + return failure(); + + ArrayRef tmpSliceRef = sliceRef; + + TerminatorOp terminatorOp; + // Match returnOp and update tmpSliceRef. + if (failed(matchAndDropBack(tmpSliceRef, terminatorOp))) { + LLVM_DEBUG(DBGS() << "FAIL: inplaceOpToTerm pattern -> slice must end with " + "a known terminator\n"); + return failure(); + } + + InplaceOpInterface inplaceOp; + if (failed(matchAndDropBack(tmpSliceRef, inplaceOp))) { + LLVM_DEBUG(DBGS() << "FAIL: inplaceOpToTerm pattern -> slice must end with " + "an inplaceOp\n"); + return failure(); + } + OpResult res; + for (auto &opOperand : inplaceOp->getOpOperands()) { + if (opOperand.get() != candidate) + continue; + res = inplaceOp.getTiedOpResult(opOperand); + break; + } + if (!res) { + LLVM_DEBUG(DBGS() << "FAIL: inplaceOpToTerm pattern -> slice must end with " + "inplaceable use into some OpResult\n"); + return failure(); + } + + // Commit what has been detected. + setInPlaceOpResult(res); + sliceRef = tmpSliceRef; + LLVM_DEBUG(DBGS() << "SUCCESS: inplaceOpToTerm pattern\n"); + + return success(); +} + +/// The following uses internal knowledge of the position of tied operand / +/// results. +// TODO: TieOperandInterface. +static void propagateInPlace(const SmallVector &initalWorklist, + const DominanceInfo &domInfo) { + LLVM_DEBUG(DBGS() << "\n\n"); + LLVM_DEBUG(DBGS() << "Start propagateInPlace from initial WL\n"); + for (OpOperand *operand : initalWorklist) + LLVM_DEBUG(DBGS() << "WL item: " << operand->get() << " used by " + << *operand->getOwner() << "\n"); + SmallVector worklist(initalWorklist); + for (unsigned idx = 0; idx < worklist.size(); ++idx) { + // TODO: bail on subtensor/subtensor_insert and vector.transfer_read/write + // that should have been already captured in destructive update patterns? + OpOperand &operand = *worklist[idx]; + LLVM_DEBUG(DBGS() << "WL item: " << *operand.getOwner() << "\n"); + // If the owner turns out to be a CallOp without + // `kWriteableFuncBufferArgsAttrName` this will be a noop. + if (auto inplaceOp = dyn_cast(operand.getOwner())) { + if (isBufferizableInPlace(operand, domInfo)) { + LLVM_DEBUG(DBGS() << "bufferizable inplace\n"); + setInPlaceOpResult(inplaceOp.getTiedOpResult(operand)); + } + } + LLVM_DEBUG(DBGS() << "propagatedInPlace: " << *operand.getOwner() << "\n"); + // use can have interfering reads that prevent it from being written inPlace + // but the values it produces are still themselves candidates for inPlace at + // their point of use. 
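+    // E.g. even when %t cannot be updated in place by
+    //   %r = linalg.fill(%t, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+    // a later user of %r may still bufferize in place, so all uses of the
+    // results are appended to the worklist.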
+ for (Value v : operand.getOwner()->getResults()) { + LLVM_DEBUG(DBGS() << "propagate result: " << v << "\n"); + for (auto &use : v.getUses()) { + LLVM_DEBUG(DBGS() << "add use to WL: " << use.get() << "\n"); + worklist.push_back(&use); + } + } + } + LLVM_DEBUG(DBGS() << "\n\n"); +} + +static void propagateInPlace(OpOperand &opOperand, + const DominanceInfo &domInfo) { + SmallVector worklist{&opOperand}; + propagateInPlace(worklist, domInfo); +} + +static void propagateInPlace(BlockArgument &bbArg, + const DominanceInfo &domInfo) { + SmallVector worklist; + for (auto &use : bbArg.getUses()) + worklist.push_back(&use); + propagateInPlace(worklist, domInfo); +} + +/// Iterate over bbArgs of `parentOp` and determine if they are the root of a +/// known destructive update chain. Such a destructive update is related to +/// traditional loop nest + memory analysis but provides a simpler SSA use-def +/// chain-based abstraction. +static void destructiveUpdateAnalysis(Block *block, + const DominanceInfo &domInfo) { + Operation *parentOp = block->getParentOp(); + for (BlockArgument candidate : block->getArguments()) { + LLVM_DEBUG(llvm::dbgs() << "\n\n"); + LLVM_DEBUG(DBGS() << "Destructive update analysis on candidate: " + << candidate << "\nof:\n" + << *parentOp << "\n"); + + if (!candidate.getType().isa()) { + LLVM_DEBUG(DBGS() << "Not a tensor\n"); + continue; + } + + // FuncOp arguments must be inplaceable otherwise they cannot be the root of + // a destructive update chain. + if (isa(parentOp) && getInPlace(candidate) != InPlaceSpec::True) { + LLVM_DEBUG(DBGS() << "Not inplace\n"); + continue; + } + + llvm::SetVector slice; + getForwardSlice(candidate, &slice, + [&](Operation *op) { return op->getBlock() == block; }); + + LLVM_DEBUG(DBGS() << "Slice:\n"); + for (auto *op : slice) + LLVM_DEBUG(DBGS() << *op << "\n"); + + ArrayRef sliceRef = slice.getArrayRef(); + bool failedDetectingDestructiveUpdate = + // func / return inplace patterns. + failed(detectInplaceOpToTerminator( + parentOp, candidate, sliceRef)); + if (failedDetectingDestructiveUpdate) { + LLVM_DEBUG(DBGS() << "Failed to detect a destructive update pattern\n"); + continue; + } + + propagateInPlace(candidate, domInfo); + } +} + +//===----------------------------------------------------------------------===// +// Bufferization as simple BlockAndValueMapping rewrites. +//===----------------------------------------------------------------------===// + +/// Helper function for LinalgOp bufferization. +/// Operate on mixed tensor + buffer Linalg ops for progressive bufferization. +/// Allocate the output buffers for the remaining tensor output operands of +/// the Linalg op. If the tensor is an "init" tensor (i.e. its value is +/// actually used in the payload region), we additionally copy the original +/// value into the newly allocated buffer. +static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op, + SmallVectorImpl &resultBuffers, + BlockAndValueMapping &bvm) { + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + + // Lazily compute loopRanges. + SmallVector loopRanges; + + // Linalg invariant: output tensors and result match 1-1. + assert(op.getNumOutputTensors() == op->getNumResults()); + for (auto &opOperand : op.getOutputOpOperands()) { + Value output = opOperand.get(); + if (output.getType().isa()) { + resultBuffers.push_back(output); + continue; + } + + // If output tensor is marked inPlace, just use the buffer. 
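+    // (getInPlace below reads the `__inplace_results_attr__` entry that the
+    // analysis recorded earlier via setInPlaceOpResult.)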
+ // The following uses internal knowledge of the position of tied operand / + // results. A proper TieOperandInterface would be much better. + auto inplaceOp = cast(op.getOperation()); + OpResult tiedResult = inplaceOp.getTiedOpResult(opOperand); + if (getInPlace(tiedResult) == InPlaceSpec::True) { + resultBuffers.push_back(lookup(bvm, output)); + continue; + } + + Value dimTensor = bvm.lookupOrDefault(output); + Value alloc = createNewAllocDeallocPairForShapedValue(b, loc, dimTensor); + b.setInsertionPointAfter(alloc.getDefiningOp()); + resultBuffers.push_back(alloc); + + // Additionally, if the output buffer is used, clone its value for now. + if (op.payloadUsesValueFromOpOperand(&opOperand)) + b.create(loc, lookup(bvm, output), alloc); + } + if (op->getNumResults()) + map(bvm, op->getResults(), resultBuffers); +} + +static void finalizeBufferAllocation(OpBuilder &b, LinalgOp op, + ValueRange inputs, ValueRange outputs, + BlockAndValueMapping &bvm) { + SmallVector newOperands = inputs; + newOperands.append(outputs.begin(), outputs.end()); + auto otherOperands = op.getAssumedNonShapedOperands(); + newOperands.append(otherOperands.begin(), otherOperands.end()); + Location loc = op.getLoc(); + op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands); + + // Replace the results of the old op with the new output buffers. + if (op->getNumResults()) + map(bvm, op->getResults(), outputs); + if (!op.hasTensorSemantics()) + op->erase(); +} + +/// Generic conversion for any LinalgOp. +/// Operate on mixed tensor + buffer Linalg ops for progressive bufferization. +static LogicalResult convertAnyLinalgOp(OpBuilder &b, LinalgOp op, + BlockAndValueMapping &bvm) { + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + + if (op.hasBufferSemantics()) + return failure(); + + LLVM_DEBUG(DBGS() << "convert: " << *op << "\n"); + + b.setInsertionPoint(op); + Location loc = op.getLoc(); + SmallVector newInputBuffers; + newInputBuffers.reserve(op.getNumInputs()); + for (Value v : op.getInputs()) + newInputBuffers.push_back(lookup(bvm, v)); + SmallVector newOutputBuffers; + allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm); + finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm); + return success(); +} + +/// DimOp tensor operand is modified inplace. This allows leaving dead tensors +/// behind that will get DCE'd. +static LogicalResult convertDimOp(OpBuilder &b, memref::DimOp dimOp, + BlockAndValueMapping &bvm) { + if (dimOp.memrefOrTensor().getType().isa()) + dimOp.memrefOrTensorMutable().assign(lookup(bvm, dimOp.memrefOrTensor())); + return success(); +} + +/// FuncOp always creates TensorToMemRef ops. +static LogicalResult convertFuncOp(OpBuilder &b, FuncOp funcOp, + BlockAndValueMapping &bvm) { + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(&funcOp.body().front()); + for (auto bbArg : funcOp.getArguments()) { + auto tensorType = bbArg.getType().dyn_cast(); + if (!tensorType) + continue; + auto rankedTensorType = tensorType.dyn_cast(); + // Cast the tensor to the most dynamic buffer possible. Further + // canonicalizations will clean up. + Type memRefType = rankedTensorType + ? getDynamicMemRefType(rankedTensorType) + : getContiguousOrUnrankedMemRefType(tensorType); + Value tensorToMemref = + b.create(funcOp.getLoc(), memRefType, bbArg); + map(bvm, bbArg, tensorToMemref); + } + return success(); +} + +/// ReturnOp always creates memref::TensorLoadOp. 
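+/// E.g. `return %r : tensor<?xf32>` is rewritten so that %r is re-materialized
+/// as a `memref.tensor_load` of the buffer mapped to it, keeping the
+/// function's tensor return type intact.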
+static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp, + BlockAndValueMapping &bvm) { + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(returnOp); + + FuncOp funcOp = cast(returnOp->getParentOp()); + assert(funcOp && "only support FuncOp parent for ReturnOp"); + for (OpOperand &operand : returnOp->getOpOperands()) { + auto tensorType = operand.get().getType().dyn_cast(); + if (!tensorType) + continue; + operand.set(b.create(returnOp.getLoc(), + lookup(bvm, operand.get()))); + } + return success(); +} + +static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp, + const DominanceInfo &domInfo) { + assert(funcOp && funcOp->getNumRegions() > 0 || + !funcOp.body().empty() && "expected a funcOp definition with a body"); + + // Start propagating from FuncOp bbArgs. + destructiveUpdateAnalysis(&funcOp.body().front(), domInfo); +} + +static LogicalResult bufferizeFuncOpInternals( + FuncOp funcOp, BlockAndValueMapping &bvm, + const DenseMap> &tiedResultsMap) { + OpBuilder b(funcOp->getContext()); + /// Start by converting `funcOp` arguments. + if (failed(convertFuncOp(b, funcOp, bvm))) + return failure(); + WalkResult result = funcOp.walk([&](Operation *op) { + LogicalResult status = + llvm::TypeSwitch(op) + // Skip BufferCast and TensorLoad ops. + .Case( + [&](auto) { return success(); }) + .Case([&](memref::DimOp op) { return convertDimOp(b, op, bvm); }) + .Case([&](LinalgOp op) { return convertAnyLinalgOp(b, op, bvm); }) + .Case([&](ReturnOp op) { return convertReturnOp(b, op, bvm); }) + .Default([&](Operation *op) { + auto isaTensor = [](Type t) { return t.isa(); }; + if (llvm::any_of(op->getOperandTypes(), isaTensor) || + llvm::any_of(op->getOperandTypes(), isaTensor)) + return failure(); + return success(); + }); + if (failed(status)) { + op->emitError("Failed bufferization"); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) + return failure(); + return success(); +} + +namespace { +struct LinalgComprehensiveFuncBufferize + : public LinalgComprehensiveFuncBufferizeBase< + LinalgComprehensiveFuncBufferize> { + void runOnFunction(); + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } +}; +} // end namespace + +void LinalgComprehensiveFuncBufferize::runOnFunction() { + auto funcOp = getFunction(); + DominanceInfo domInfo(funcOp); + BlockAndValueMapping bvm; + DenseMap> tiedResultsMap; + inPlaceAnalysisFuncOpInternals(funcOp, domInfo); + + LLVM_DEBUG(DBGS() << "Begin BufferizeFuncOpInternals:\n" << funcOp << "\n"); + auto guard = llvm::make_scope_exit([&] { + funcOp.walk( + [&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); }); + LLVM_DEBUG(DBGS() << "End BufferizeFuncOpInternals:\n" << funcOp << "\n"); + }); + if (failed(bufferizeFuncOpInternals(funcOp, bvm, tiedResultsMap))) + signalPassFailure(); +} + +std::unique_ptr mlir::createLinalgComprehensiveFuncBufferizePass() { + return std::make_unique(); +} diff --git a/mlir/lib/Interfaces/InplaceInterface.cpp b/mlir/lib/Interfaces/InplaceInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/InplaceInterface.cpp @@ -0,0 +1,17 @@ +//===- InplaceInterface.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InplaceInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Table-generated class definitions +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InplaceInterface.cpp.inc" diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir @@ -0,0 +1,44 @@ +// RUN: mlir-opt %s -linalg-comprehensive-func-bufferize -split-input-file | FileCheck %s + +// CHECK-DAG: #[[$map_2d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> + +// CHECK-LABEL: func @fill_inplace( +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor {linalg.inplaceable = true}) +func @fill_inplace(%A : tensor {linalg.inplaceable = true}) -> tensor { + // CHECK: %[[I:.*]] = memref.buffer_cast %[[A]] : memref + + // CHECK: %[[F0:.*]] = constant 0.000000e+00 : f32 + %f0 = constant 0.0 : f32 + + // CHECK: linalg.fill(%[[I]], %[[F0]]) : memref, f32 + %r = linalg.fill(%A, %f0) : tensor, f32 -> tensor + + // CHECK: %[[R:.*]] = memref.tensor_load %[[I]] : memref + // CHECK: return %[[R]] : tensor + return %r: tensor +} + +// ----- + +// CHECK-DAG: #[[$map_2d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> + +// CHECK-LABEL: func @fill_out_of_place( +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor) +func @fill_out_of_place(%A : tensor) -> tensor { + // CHECK: %[[I:.*]] = memref.buffer_cast %[[A]] : memref + + // CHECK: %[[D0:.*]] = memref.dim %[[I]], {{.*}} : memref + // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[D0]]) : memref + // CHECK: %[[I2:.*]] = memref.cast %[[ALLOC]] : memref to memref + + // CHECK: %[[F0:.*]] = constant 0.000000e+00 : f32 + %f0 = constant 0.0 : f32 + + // CHECK: linalg.fill(%[[I2]], %[[F0]]) : memref, f32 + %r = linalg.fill(%A, %f0) : tensor, f32 -> tensor + + // CHECK: dealloc %[[ALLOC]] : memref + // CHECK: %[[R:.*]] = memref.tensor_load %[[I2]] : memref + // CHECK: return %[[R]] : tensor + return %r: tensor +}
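+
+// Note: while the analysis runs, results are tagged with a transient
+// `__inplace_results_attr__` string-array attribute (e.g. ["true"] on the
+// linalg.fill above); the pass erases it before finishing, so it never
+// appears in the FileCheck output.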