diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -40,7 +40,10 @@
        "Only runs inplaceability analysis (for testing purposes only)">,
     Option<"allowReturnMemref", "allow-return-memref", "bool",
            /*default=*/"false",
-           "Allows the return of memrefs (for testing purposes only)">
+           "Allows the return of memrefs (for testing purposes only)">,
+    Option<"useAlloca", "use-alloca", "bool",
+           /*default=*/"false",
+           "Use stack allocations for memrefs (for testing purposes only)">
   ];
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -175,14 +175,36 @@
                                  BufferizationAliasInfo &aliasInfo,
                                  const DominanceInfo &domInfo);
 
+/// Default allocation function that is used by the comprehensive
+/// bufferization pass. The default currently creates a ranked memref using
+/// `memref.alloc`.
+Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
+                                    Value shapedValue);
+
+/// Default deallocation function that is used by the comprehensive
+/// bufferization pass. It expects to receive the buffer that was returned by
+/// `defaultAllocationFn`.
+void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer);
+
+/// Callback functions that are used by the comprehensive bufferization pass
+/// to allocate/deallocate memory. These default to
+/// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by
+/// the caller. The `deallocationFn` is guaranteed to receive the `Value`
+/// returned by the `allocationFn`.
+struct AllocationCallbacks {
+  std::function<Optional<Value>(OpBuilder &b, Location loc, Value shapedValue)>
+      allocationFn = defaultAllocationFn;
+  std::function<void(OpBuilder &b, Location loc, Value allocatedBuffer)>
+      deallocationFn = defaultDeallocationFn;
+};
+
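The struct above is the full extension surface of this change. Below is a minimal sketch of a caller-side override, assuming only `AllocationCallbacks`, `defaultAllocationFn`, and `defaultDeallocationFn` from this header; the helper name and the static-shape policy are illustrative, and note that the pass currently asserts when `allocationFn` returns `llvm::None` (see the TODO in ComprehensiveBufferize.cpp):

```cpp
#include "mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h"

using namespace mlir;
using namespace mlir::linalg;

// Hypothetical helper: fall back to the default allocation for statically
// shaped values, refuse dynamic ones.
static AllocationCallbacks makeStaticOnlyCallbacks() {
  AllocationCallbacks callbacks;
  callbacks.allocationFn = [](OpBuilder &b, Location loc,
                              Value shapedValue) -> Optional<Value> {
    auto type = shapedValue.getType().dyn_cast<ShapedType>();
    if (!type || !type.hasStaticShape())
      return llvm::None; // Signals allocation failure to the caller.
    return defaultAllocationFn(b, loc, shapedValue);
  };
  // deallocationFn keeps its default; it is guaranteed to receive exactly
  // the Value returned by allocationFn above.
  return callbacks;
}
```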
 /// Bufferize one particular op.
-/// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
-/// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
+/// `bufferizedFunctionTypes` is expected to be non-null if `op` is a
+/// CallOpInterface.
 LogicalResult bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
                           BufferizationAliasInfo &aliasInfo,
-                          DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr,
-                          GlobalCreator *globalCreator = nullptr);
+                          AllocationCallbacks allocationFns,
+                          DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);
 
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -33,6 +33,7 @@
   MLIRAnalysis
   MLIRArithmetic
   MLIRComplex
+  MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemRef
   MLIRLinalgAnalysis
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -118,12 +118,12 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/BufferUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
-
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetVector.h"
@@ -983,7 +983,6 @@
     const DenseSet<OpOperand *> &usesRead,
     const DenseSet<OpOperand *> &usesWrite,
     const DominanceInfo &domInfo) const {
-
   for (OpOperand *uRead : usesRead) {
     Operation *readingOp = uRead->getOwner();
@@ -1415,66 +1414,27 @@
 /// Create an AllocOp/DeallocOp pair, where the AllocOp is after
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
-static Value
-createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
-                                        Value shapedValue,
-                                        BufferizationAliasInfo &aliasInfo) {
+static Value createNewAllocDeallocPairForShapedValue(
+    OpBuilder &b, Location loc, Value shapedValue,
+    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
-  // TODO: non-zero address space.
-  // TODO: layout information if relevant.
-  // Cannot allocate an unranked memref so just always go for the contiguous
-  // form.
-  MemRefType allocMemRefType =
-      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
   assert(shapedValue.getType().isa<ShapedType>());
   MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
-  memRefType = memRefType ? memRefType : allocMemRefType;
 
-  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
-    b.setInsertionPointToStart(bbArg.getOwner());
-    loc = bbArg.getOwner()->getParentOp()->getLoc();
-  } else {
-    b.setInsertionPoint(shapedValue.getDefiningOp());
-    loc = shapedValue.getDefiningOp()->getLoc();
-  }
+  Optional<Value> allocated = allocationFns.allocationFn(b, loc, shapedValue);
+  // TODO: For now just assert that a value is returned. Eventually this needs
+  // to propagate the error.
+  assert(allocated && "allocation failed");
+  Value casted = allocated.getValue();
+  MemRefType allocMemRefType = allocated->getType().cast<MemRefType>();
+  if (memRefType && memRefType != allocMemRefType) {
+    casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
+    aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
+  }
 
-  // Compute the dynamic part of the shape.
-  SmallVector<Value> dynShape;
-  for (auto dim : enumerate(memRefType.getShape()))
-    if (dim.value() == ShapedType::kDynamicSize)
-      dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
-
-  // If the buffer is statically shaped, try to hoist it to the first
-  // enclosing parallel region.
-  // TODO: this concept of parallel region and threadlocal needs interfaces.
-  // TODO: also hoist in the dynamic case. For now this relies on subsequent
-  // calls to LICM and buffer hoisting, which will most likely not succeed.
-  // TODO: when packing, allocate a static bounding box which will enable more
-  // hoisting.
-  Value allocated;
-  { // Guarded insertion point to potentially hoist the AllocOp.
-    OpBuilder::InsertionGuard g(b);
-    if (dynShape.empty()) {
-      Operation *parent =
-          getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
-                               AffineParallelOp>(shapedValue);
-      if (parent)
-        b.setInsertionPointToStart(&(parent->getRegion(0).front()));
-    }
-    allocated = b.create<memref::AllocOp>(
-        loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
-    aliasInfo.createAliasInfoEntry(allocated);
-  }
-  Value casted = allocated;
-  if (memRefType != allocMemRefType) {
-    casted = b.create<memref::CastOp>(loc, memRefType, allocated);
-    aliasInfo.insertNewBufferEquivalence(casted, allocated);
-  }
-  b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
-  b.create<memref::DeallocOp>(loc, allocated);
-
+  allocationFns.deallocationFn(b, loc, allocated.getValue());
   return casted;
 }
@@ -1488,6 +1448,7 @@
 static Value getResultBuffer(OpBuilder &b, OpResult result,
                              const BlockAndValueMapping &bvm,
                              BufferizationAliasInfo &aliasInfo,
+                             AllocationCallbacks allocationFns,
                              bool skipCopy = false) {
   OpBuilder::InsertionGuard guard(b);
   Operation *op = result.getOwner();
@@ -1515,8 +1476,8 @@
            "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
     Location loc = op->getLoc();
     // Allocate the result buffer.
-    Value resultBuffer =
-        createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
+    Value resultBuffer = createNewAllocDeallocPairForShapedValue(
+        b, loc, operand, aliasInfo, allocationFns);
     // Do not copy the result of an InitTensorOp.
     if (isInitTensorOp(operand))
       skipCopy = true;
@@ -1538,11 +1499,10 @@
 /// Helper function for LinalgOp bufferization.
 /// When allocating a new buffer, analyze whether `op` wants to read from that
 /// buffer. Only in that case may a copy of the result buffer be needed.
-static LogicalResult
-allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
-                          SmallVectorImpl<Value> &resultBuffers,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo) {
+static LogicalResult allocateBuffersForResults(
+    OpBuilder &b, Location loc, LinalgOp op,
+    SmallVectorImpl<Value> &resultBuffers, BlockAndValueMapping &bvm,
+    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -1553,7 +1513,8 @@
     OpResult opResult = getInplaceableOpResult(*opOperand);
     assert(opResult && "could not find correspond OpResult");
     bool skipCopy = !op.payloadUsesValueFromOperand(opOperand);
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo, skipCopy);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns, skipCopy);
     if (!resultBuffer)
       return failure();
     resultBuffers.push_back(resultBuffer);
@@ -1568,7 +1529,8 @@
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFns) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1591,7 +1553,7 @@
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
   if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
-                                       aliasInfo)))
+                                       aliasInfo, allocationFns)))
     return failure();
 
   // Clone the newly bufferized op.
@@ -1616,7 +1578,7 @@
 /// to allow FuncOp that are inplaceable to write inPlace.
 static LogicalResult
 bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
-          BufferizationAliasInfo &aliasInfo,
+          BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns,
           DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   FuncOp funcOp = getCalledFunction(callOp);
   assert(isa<CallOpInterface>(callOp.getOperation()) && funcOp &&
@@ -1755,12 +1717,14 @@
 /// tensor::CastOp bufferizes to memref::CastOp.
 static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(castOp);
 
-  Value resultBuffer = getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo);
+  Value resultBuffer =
+      getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn);
   if (!resultBuffer)
     return failure();
   Type sourceType = resultBuffer.getType();
@@ -1786,10 +1750,15 @@
 static LogicalResult bufferize(OpBuilder &b, arith::ConstantOp constantOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               GlobalCreator &globalCreator) {
+                               BufferizationAliasInfo &aliasInfo) {
   assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
          "not a constant ranked tensor");
+  auto moduleOp = constantOp->getParentOfType<ModuleOp>();
+  if (!moduleOp) {
+    return constantOp.emitError(
+        "cannot bufferize constants not within builtin.module op");
+  }
+  GlobalCreator globalCreator(moduleOp);
 
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -1824,7 +1793,8 @@
 static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1837,7 +1807,8 @@
            "unsupported unranked tensor");
 
     // TODO: More general: Matching bbArg does not bufferize to a read.
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
     if (!resultBuffer)
       return failure();
 
@@ -1854,7 +1825,8 @@
 static LogicalResult bufferize(OpBuilder &b, scf::IfOp ifOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1866,7 +1838,8 @@
     assert(opResult.getType().isa<TensorType>() &&
            "unsupported unranked tensor");
 
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
     if (!resultBuffer)
       return failure();
 
@@ -1880,7 +1853,8 @@
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());
@@ -1906,7 +1880,8 @@
 /// TODO: consider hoisting across function boundaries prior to bufferization.
 static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // The InitTensorOp may have been eliminated.
   if (initTensorOp->getUses().empty())
     return success();
@@ -1916,7 +1891,8 @@
   b.setInsertionPoint(initTensorOp);
   Value alloc = createNewAllocDeallocPairForShapedValue(
-      b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo);
+      b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo,
+      allocationFn);
   map(bvm, initTensorOp.result(), alloc);
   return success();
 }
@@ -1949,7 +1925,8 @@
 /// Bufferization for TiledLoopOp.
 static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1989,7 +1966,8 @@
     const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
     OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
     if (!resultBuffer)
       return failure();
 
@@ -2073,7 +2051,8 @@
 /// isolation.
 static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -2093,7 +2072,7 @@
   auto inPlace = getInPlace(extractSliceOp->getResult(0));
   if (inPlace != InPlaceSpec::True)
     alloc = createNewAllocDeallocPairForShapedValue(
-        b, loc, extractSliceOp.result(), aliasInfo);
+        b, loc, extractSliceOp.result(), aliasInfo, allocationFn);
 
   // Set insertion point now that potential alloc/dealloc are introduced.
   b.setInsertionPoint(extractSliceOp);
@@ -2125,7 +2104,8 @@
 static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(insertSliceOp);
@@ -2140,8 +2120,8 @@
     // TODO: be very loud about it or even consider failing the pass.
     // Alloc a copy for `insertSliceOp.dest()`; it will become the result
     // buffer.
-    Value dstMemref =
-        getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
+    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm,
+                                      aliasInfo, allocationFn);
     if (!dstMemref)
       return failure();
     auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
@@ -2184,7 +2164,8 @@
 static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
 
@@ -2205,7 +2186,8 @@
   // Leave the previous transfer_write as dead code; it still has uses at
   // this point.
   auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
-  Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo);
+  Value resultBuffer =
+      getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn);
   if (!resultBuffer)
     return failure();
   b.create<vector::TransferWriteOp>(
@@ -2436,43 +2418,124 @@
 // Bufferization entry-point for functions.
 //===----------------------------------------------------------------------===//
 
+/// Compute the type of the `memref` to use for allocating the buffer for
+/// `shapedValue`. Also returns, by reference in `dynShape`, the values for
+/// the dynamic dimensions of the returned `memref` type, and sets the
+/// insertion point of the builder `b` to the position where the allocation
+/// is to be inserted.
+static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
+                                            Value shapedValue,
+                                            SmallVectorImpl<Value> &dynShape) {
+  MemRefType allocMemRefType =
+      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
+  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
+    b.setInsertionPointToStart(bbArg.getOwner());
+    loc = bbArg.getOwner()->getParentOp()->getLoc();
+  } else {
+    b.setInsertionPoint(shapedValue.getDefiningOp());
+    loc = shapedValue.getDefiningOp()->getLoc();
+  }
+
+  // Compute the dynamic part of the shape.
+  bool foundDynamicShapes = false;
+  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
+          shapedValue.getDefiningOp())) {
+    ReifiedRankedShapedTypeDims resultDims;
+    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
+      foundDynamicShapes = true;
+      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
+      auto &shape = resultDims[resultValue.getResultNumber()];
+      for (auto dim : enumerate(allocMemRefType.getShape()))
+        if (dim.value() == ShapedType::kDynamicSize)
+          dynShape.push_back(shape[dim.index()]);
+    }
+  }
+  if (!foundDynamicShapes) {
+    for (auto dim : enumerate(allocMemRefType.getShape()))
+      if (dim.value() == ShapedType::kDynamicSize)
+        dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
+  }
+
+  // If the buffer is statically shaped, try to hoist it to the first
+  // enclosing parallel region.
+  // TODO: this concept of parallel region and threadlocal needs interfaces.
+  // TODO: also hoist in the dynamic case. For now this relies on subsequent
+  // calls to LICM and buffer hoisting, which will most likely not succeed.
+  // TODO: when packing, allocate a static bounding box which will enable more
+  // hoisting.
+  if (dynShape.empty()) {
+    Operation *parent =
+        getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
+                             AffineParallelOp>(shapedValue);
+    if (parent)
+      b.setInsertionPointToStart(&(parent->getRegion(0).front()));
+  }
+  return allocMemRefType;
+}
+
+Optional<Value> mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc,
+                                                  Value shapedValue) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  SmallVector<Value> dynShape;
+  MemRefType allocMemRefType =
+      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+  Value allocated = b.create<memref::AllocOp>(
+      loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+  return allocated;
+}
+
+static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
+                                               Value shapedValue) {
+  OpBuilder::InsertionGuard g(b);
+  SmallVector<Value> dynShape;
+  MemRefType allocMemRefType =
+      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+  Value allocated = b.create<memref::AllocaOp>(
+      loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+  return allocated;
+}
+
+void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
+                                         Value allocatedBuffer) {
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(allocatedBuffer.getParentBlock()->getTerminator());
+  b.create<memref::DeallocOp>(loc, allocatedBuffer);
+}
+
 LogicalResult mlir::linalg::bufferizeOp(
     Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes,
-    GlobalCreator *globalCreator) {
+    AllocationCallbacks allocationFns,
+    DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
   OpBuilder b(op->getContext());
   return TypeSwitch<Operation *, LogicalResult>(op)
       // Skip BufferCast and TensorLoad ops.
       .Case<memref::BufferCastOp, memref::TensorLoadOp>(
           [&](auto) { return success(); })
-      .Case<tensor::CastOp, scf::ForOp, scf::IfOp, InitTensorOp, InsertSliceOp,
-            ExtractSliceOp, LinalgOp, FuncOp, TiledLoopOp, ReturnOp,
-            VectorTransferOpInterface, scf::YieldOp, linalg::YieldOp>(
-          [&](auto op) {
-            LDBG("Begin bufferize:\n" << op << '\n');
-            return bufferize(b, op, bvm, aliasInfo);
-          })
+      .Case<tensor::CastOp, scf::ForOp, scf::IfOp, InitTensorOp, InsertSliceOp,
+            ExtractSliceOp, LinalgOp, FuncOp, TiledLoopOp,
+            VectorTransferOpInterface>(
+          [&](auto op) {
+            LDBG("Begin bufferize:\n" << op << '\n');
+            return bufferize(b, op, bvm, aliasInfo, allocationFns);
+          })
+      .Case<ReturnOp, scf::YieldOp, linalg::YieldOp>([&](auto op) {
+        LDBG("Begin bufferize:\n" << op << '\n');
+        return bufferize(b, op, bvm, aliasInfo);
+      })
       .Case([&](CallOpInterface op) {
         LDBG("Begin bufferize:\n" << op << '\n');
         if (!bufferizedFunctionTypes)
           llvm_unreachable(
               "null bufferizedFunctionTypes when bufferizing CallOpInterface");
-        return bufferize(b, op, bvm, aliasInfo, *bufferizedFunctionTypes);
+        return bufferize(b, op, bvm, aliasInfo, allocationFns,
+                         *bufferizedFunctionTypes);
       })
       .Case([&](arith::ConstantOp op) {
         if (!isaTensor(op.getResult().getType()))
           return success();
         LDBG("Begin bufferize:\n" << op << '\n');
-        if (!globalCreator)
-          llvm_unreachable("null globalCreator when bufferizing ConstantOp");
-        return bufferize(b, op, bvm, aliasInfo, *globalCreator);
+        return bufferize(b, op, bvm, aliasInfo);
       })
       .Default([&](Operation *op) -> LogicalResult {
         auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -2485,15 +2548,13 @@
 static LogicalResult bufferizeFuncOpInternals(
     FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes,
-    GlobalCreator &globalCreator) {
-
+    AllocationCallbacks &allocationFns,
+    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
-
-  // Start by bufferizing `funcOp` arguments.
-  if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
+  // Start by bufferizing `funcOp` arguments.
+  if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns)))
     return failure();
 
   // Cannot erase ops during the traversal. Do that afterwards.
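For orientation, a sketch of how the new entry point is meant to be driven. The walk-based loop is a simplification of `bufferizeFuncOpInternals` above; the inplaceability analysis the real pass runs first is omitted, and the wrapper name is illustrative:

```cpp
using namespace mlir;
using namespace mlir::linalg;

static LogicalResult bufferizeAllOpsInFunc(FuncOp funcOp) {
  BlockAndValueMapping bvm;
  BufferizationAliasInfo aliasInfo(funcOp);
  // Default-constructed callbacks bufferize with memref.alloc/memref.dealloc.
  AllocationCallbacks allocationFns;
  DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
  WalkResult result = funcOp.walk([&](Operation *op) {
    if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
                           &bufferizedFunctionTypes)))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return failure(result.wasInterrupted());
}
```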
@@ -2516,13 +2577,13 @@
   }
 
   for (Operation *op : llvm::reverse(preorderBufferize))
-    if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
-                           &globalCreator)))
+    if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
+                           &bufferizedFunctionTypes)))
       return failure();
 
     if (!bufferizedOps.contains(op) &&
-        failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
-                           &globalCreator)))
+        failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
+                           &bufferizedFunctionTypes)))
       return failure();
 
     // Register post-walk erasure, if necessary.
@@ -2793,12 +2854,19 @@
 struct LinalgComprehensiveModuleBufferize
     : public LinalgComprehensiveModuleBufferizeBase<
           LinalgComprehensiveModuleBufferize> {
+  LinalgComprehensiveModuleBufferize() {}
+
+  LinalgComprehensiveModuleBufferize(
+      const LinalgComprehensiveModuleBufferize &p) {}
 
   void runOnOperation() override;
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<linalg::LinalgDialect, memref::MemRefDialect,
                     tensor::TensorDialect, vector::VectorDialect,
                     scf::SCFDialect, arith::ArithmeticDialect,
                     StandardOpsDialect, AffineDialect>();
   }
+
+private:
+  std::unique_ptr<AllocationCallbacks> allocationFns;
 };
 } // end namespace
@@ -2983,6 +3051,22 @@
 }
 
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
+  if (!allocationFns) {
+    // The allocation functions need to be set up here rather than in the
+    // constructor: the pass flag and the alloca flag map to LLVM command-line
+    // options, which are static global objects with no guaranteed
+    // initialization order. The constructor might therefore run before the
+    // flag has been initialized from the command line, so the callbacks are
+    // configured at the start of the pass instead.
+    if (useAlloca) {
+      AllocationCallbacks allocaAllocationFns = {
+          allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {}};
+      allocationFns = std::make_unique<AllocationCallbacks>(
+          std::move(allocaAllocationFns));
+    } else {
+      allocationFns = std::make_unique<AllocationCallbacks>();
+    }
+  }
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
@@ -2992,7 +3076,6 @@
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
     return signalPassFailure();
 
-  GlobalCreator globalCreator(moduleOp);
   DominanceInfo domInfo(moduleOp);
   BufferizationAliasInfo aliasInfo(moduleOp);
   // Interestingly, all function args that are not visible outside of a module
@@ -3032,8 +3115,8 @@
     if (!testAnalysisOnly) {
       BlockAndValueMapping tensorToBufferMap;
       if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
-                                          bufferizedFunctionTypes,
-                                          globalCreator))) {
+                                          *allocationFns,
+                                          bufferizedFunctionTypes))) {
         signalPassFailure();
         return;
       }
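The test below exercises the new option end to end. Programmatically, the same configuration can presumably be reached by parsing the textual pipeline from the RUN line. A sketch, assuming standard MLIR pass-manager APIs (`parsePassPipeline`, `PassManager`); the wrapper function is illustrative:

```cpp
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

using namespace mlir;

static LogicalResult bufferizeWithStackAllocations(ModuleOp module) {
  PassManager pm(module.getContext());
  // Same pipeline string as the RUN line of the new test.
  if (failed(parsePassPipeline(
          "linalg-comprehensive-module-bufferize{allow-return-memref "
          "use-alloca}",
          pm)))
    return failure();
  return pm.run(module);
}
```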
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -pass-pipeline="linalg-comprehensive-module-bufferize{allow-return-memref use-alloca}" -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK: func @init_and_dot(
+// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
+func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
+  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0{{.*}} : f32
+  %v0 = arith.constant 0.0 : f32
+
+  // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32, #[[$DYN_0D_MAP]]>
+  %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+
+  // CHECK-NEXT: linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
+  %e = linalg.dot ins(%a, %b : tensor<64xf32>, tensor<64xf32>)
+    outs(%d: tensor<f32>) -> tensor<f32>
+
+  // CHECK-NEXT: return
+  return %e : tensor<f32>
+}
+
+// CHECK: func @main()
+func @main() {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0{{.*}} : f32
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1{{.*}} : f32
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2{{.*}} : f32
+  %v0 = arith.constant 0.0 : f32
+  %v1 = arith.constant 1.0 : f32
+  %v2 = arith.constant 2.0 : f32
+
+  // CHECK-NEXT: %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref<f32>
+  // CHECK-NEXT: %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
+  // CHECK-NEXT: %[[A:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
+  %A = linalg.init_tensor [64] : tensor<64xf32>
+  %B = linalg.init_tensor [64] : tensor<64xf32>
+  %C = linalg.init_tensor [] : tensor<f32>
+
+  // CHECK-NEXT: linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32>
+  // CHECK-NEXT: linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32>
+  // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32>
+  %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
+  %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
+  %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
+
+  // CHECK-NEXT: %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+  // CHECK-NEXT: %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+  // CHECK-NEXT: %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>
+  // CHECK-NEXT: call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
+  %res = call @init_and_dot(%AA, %BB, %CC) :
+    (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
+
+  // CHECK-NEXT: %[[dC:.*]] = memref.cast %[[C]] : memref<f32> to memref<*xf32>
+  %res2 = tensor.cast %res : tensor<f32> to tensor<*xf32>
+
+  // CHECK-NEXT: call @print_memref_f32(%[[dC]]) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
+
+  return
+}
+
+// CHECK: func private @print_memref_f32(memref<*xf32>)
+func private @print_memref_f32(tensor<*xf32>)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6314,6 +6314,7 @@
         ":ComplexDialect",
         ":DialectUtils",
         ":IR",
+        ":InferTypeOpInterface",
         ":LinalgOps",
         ":LinalgPassIncGen",
         ":LinalgStructuredOpsIncGen",