diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_LINALG_PASSES_H_
 #define MLIR_DIALECT_LINALG_PASSES_H_
 
+#include "mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Pass/Pass.h"
@@ -62,6 +63,8 @@
 /// on SSA use-def chains starting from function operands that are annotated
 /// with the 'inplaceable' attribute.
 std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
+std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass(
+    linalg::AllocationCallbacks allocationFns);
 
 /// Create a pass to convert Linalg operations which work on tensors to use
 /// buffers instead.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -40,7 +40,10 @@
            "Only runs inplaceability analysis (for testing purposes only)">,
     Option<"allowReturnMemref", "allow-return-memref", "bool",
            /*default=*/"false",
-           "Allows the return of memrefs (for testing purposes only)">
+           "Allows the return of memrefs (for testing purposes only)">,
+    Option<"useAlloca", "use-alloca", "bool",
+           /*default=*/"false",
+           "Use stack allocations for memrefs (for testing purposes only)">
   ];
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -170,14 +170,36 @@
                                     BufferizationAliasInfo &aliasInfo,
                                     const DominanceInfo &domInfo);
 
+/// Default allocation function that is used by the comprehensive bufferization
+/// pass. The default currently creates a ranked memref using `memref.alloc`.
+Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
+                                    Value shapedValue);
+
+/// Default deallocation function that is used by the comprehensive
+/// bufferization pass. It expects to receive back the value returned by the
+/// `defaultAllocationFn`.
+void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer);
+
+/// Callback functions that are used by the comprehensive bufferization pass to
+/// allocate/deallocate memory. These default to
+/// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by the
+/// caller. The `deallocationFn` is guaranteed to receive the `Value` returned
+/// by the `allocationFn`.
+struct AllocationCallbacks {
+  std::function<Optional<Value>(OpBuilder &b, Location loc, Value shapedValue)>
+      allocationFn = defaultAllocationFn;
+  std::function<void(OpBuilder &b, Location loc, Value allocatedBuffer)>
+      deallocationFn = defaultDeallocationFn;
+};
+
 /// Bufferize one particular op.
-/// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
-/// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
+/// `bufferizedFunctionTypes` is expected to be non-null if `op` is a
+/// CallOpInterface.
 LogicalResult bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
                           BufferizationAliasInfo &aliasInfo,
-                          DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr,
-                          GlobalCreator *globalCreator = nullptr);
+                          AllocationCallbacks allocationFns,
+                          DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);
 
 } // namespace linalg
 } // namespace mlir
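Taken together, the header changes above define the new extension point: `AllocationCallbacks` bundles an allocation and a deallocation hook (defaulting to `defaultAllocationFn`/`defaultDeallocationFn`), and the new `createLinalgComprehensiveModuleBufferizePass` overload threads a caller-supplied bundle into the pass. A minimal usage sketch follows; the `addCustomBufferizePasses` helper and the lambda bodies are illustrative assumptions, only the callback struct and the constructor overload come from this patch:

    // Sketch: plugging custom allocation behavior into the pass.
    #include "mlir/Dialect/Linalg/Passes.h"
    #include "mlir/Pass/PassManager.h"

    void addCustomBufferizePasses(mlir::PassManager &pm) {
      mlir::linalg::AllocationCallbacks callbacks;
      callbacks.allocationFn = [](mlir::OpBuilder &b, mlir::Location loc,
                                  mlir::Value shapedValue)
          -> llvm::Optional<mlir::Value> {
        // Delegate to the default memref.alloc-based allocation; a real
        // client would emit its own allocation here instead.
        return mlir::linalg::defaultAllocationFn(b, loc, shapedValue);
      };
      callbacks.deallocationFn = [](mlir::OpBuilder &b, mlir::Location loc,
                                    mlir::Value buffer) {
        // `buffer` is guaranteed to be the Value returned by allocationFn.
        mlir::linalg::defaultDeallocationFn(b, loc, buffer);
      };
      pm.addPass(mlir::createLinalgComprehensiveModuleBufferizePass(
          std::move(callbacks)));
    }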
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -33,6 +33,7 @@
   MLIRAnalysis
   MLIRArithmetic
   MLIRComplex
+  MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemRef
   MLIRLinalgAnalysis
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -118,12 +118,12 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/BufferUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
-
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetVector.h"
@@ -893,7 +893,6 @@
                              const DenseSet<OpOperand *> &usesRead,
                              const DenseSet<OpOperand *> &usesWrite,
                              const DominanceInfo &domInfo) const {
-
   for (OpOperand *uRead : usesRead) {
     Operation *readingOp = uRead->getOwner();
@@ -1317,66 +1316,27 @@
 /// Create an AllocOp/DeallocOp pair, where the AllocOp is after
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
-static Value
-createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
-                                        Value shapedValue,
-                                        BufferizationAliasInfo &aliasInfo) {
+static Value createNewAllocDeallocPairForShapedValue(
+    OpBuilder &b, Location loc, Value shapedValue,
+    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
-  // TODO: non-zero address space.
-  // TODO: layout information if relevant.
-  // Cannot allocate an unranked memref so just always go for the contiguous
-  // form.
-  MemRefType allocMemRefType =
-      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
   assert(shapedValue.getType().isa<ShapedType>());
   MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
-  memRefType = memRefType ? memRefType : allocMemRefType;
-
-  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
-    b.setInsertionPointToStart(bbArg.getOwner());
-    loc = bbArg.getOwner()->getParentOp()->getLoc();
-  } else {
-    b.setInsertionPoint(shapedValue.getDefiningOp());
-    loc = shapedValue.getDefiningOp()->getLoc();
-  }
-
-  // Compute the dynamic part of the shape.
-  SmallVector<Value> dynShape;
-  for (auto dim : enumerate(memRefType.getShape()))
-    if (dim.value() == ShapedType::kDynamicSize)
-      dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
-
-  // If the buffer is statically shaped, try to hoist it to the first enclosing
-  // parallel region.
-  // TODO: this concept of parallel region and threadlocal needs interfaces.
-  // TODO: also hoist in the dynamic case. For now this relies on subsequent
-  // calls to LICM and buffer hoisting which will most likely not succeed.
-  // TODO: when packing, allocate a static bounding box which will enable more
-  // hoisting.
-  Value allocated;
-  { // Guarded insertion point to potentially hoist the AllocOp.
-    OpBuilder::InsertionGuard g(b);
-    if (dynShape.empty()) {
-      Operation *parent =
-          getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
-                               AffineParallelOp>(shapedValue);
-      if (parent)
-        b.setInsertionPointToStart(&(parent->getRegion(0).front()));
-    }
-    allocated = b.create<memref::AllocOp>(
-        loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
-    aliasInfo.createAliasInfoEntry(allocated);
-  }
-  Value casted = allocated;
-  if (memRefType != allocMemRefType) {
-    casted = b.create<memref::CastOp>(loc, memRefType, allocated);
-    aliasInfo.insertNewBufferEquivalence(casted, allocated);
+  Optional<Value> allocated = allocationFns.allocationFn(b, loc, shapedValue);
+  // TODO: For now just assert the value is returned. Eventually need to
+  // error-propagate.
+  assert(allocated && "allocation failed");
+  Value casted = allocated.getValue();
+  MemRefType allocMemRefType = allocated->getType().cast<MemRefType>();
+  if (memRefType && memRefType != allocMemRefType) {
+    casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
+    aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
   }
-  b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
-  b.create<memref::DeallocOp>(loc, allocated);
+  allocationFns.deallocationFn(b, loc, allocated.getValue());
   return casted;
 }
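With the rewrite above, `createNewAllocDeallocPairForShapedValue` no longer hard-codes `memref.alloc`/`memref.dealloc`: it invokes `allocationFn`, inserts a `memref.cast` only when the callback's memref type differs from the requested one, and hands the original allocated value to `deallocationFn`. A hedged sketch of an alternative callback pair in the file-local style of ComprehensiveBufferize.cpp; the policy (no default alignment, static shapes only) is illustrative, only the signatures and the `getContiguousMemRefType` helper come from this patch:

    // Allocates at the current insertion point, without the default 128-byte
    // alignment; bails out (returns None) on dynamic shapes for simplicity.
    static Optional<Value> unalignedAllocationFn(OpBuilder &b, Location loc,
                                                 Value shapedValue) {
      MemRefType type =
          getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
      if (!type.hasStaticShape())
        return llvm::None; // The caller currently asserts on failure (see TODO).
      return b.create<memref::AllocOp>(loc, type).getResult();
    }

    // Receives exactly the Value returned by the allocation callback and frees
    // it just before the enclosing block's terminator, as the default does.
    static void blockLocalDeallocationFn(OpBuilder &b, Location loc,
                                         Value allocatedBuffer) {
      OpBuilder::InsertionGuard g(b);
      b.setInsertionPoint(allocatedBuffer.getParentBlock()->getTerminator());
      b.create<memref::DeallocOp>(loc, allocatedBuffer);
    }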
@@ -1390,6 +1350,7 @@
 static Value getResultBuffer(OpBuilder &b, OpResult result,
                              const BlockAndValueMapping &bvm,
                              BufferizationAliasInfo &aliasInfo,
+                             AllocationCallbacks allocationFns,
                              bool skipCopy = false) {
   OpBuilder::InsertionGuard guard(b);
   Operation *op = result.getOwner();
@@ -1405,8 +1366,8 @@
   if (getInPlace(result) != InPlaceSpec::True) {
     Location loc = op->getLoc();
     // Allocate the result buffer.
-    Value resultBuffer =
-        createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
+    Value resultBuffer = createNewAllocDeallocPairForShapedValue(
+        b, loc, operand, aliasInfo, allocationFns);
     if (!skipCopy && !isInitTensorOp(operand)) {
       // Set insertion point now that potential alloc/dealloc are introduced.
       b.setInsertionPoint(op);
@@ -1425,7 +1386,8 @@
 static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
                                       SmallVectorImpl<Value> &resultBuffers,
                                       BlockAndValueMapping &bvm,
-                                      BufferizationAliasInfo &aliasInfo) {
+                                      BufferizationAliasInfo &aliasInfo,
+                                      AllocationCallbacks &allocationFns) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -1436,7 +1398,8 @@
     OpResult opResult = getInplaceableOpResult(*opOperand);
     assert(opResult && "could not find correspond OpResult");
     bool skipCopy = !op.payloadUsesValueFromOperand(opOperand);
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo, skipCopy);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns, skipCopy);
     resultBuffers.push_back(resultBuffer);
   }
@@ -1447,7 +1410,8 @@
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFns) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -1469,7 +1433,8 @@
   }
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
-  allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm, aliasInfo);
+  allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm, aliasInfo,
+                            allocationFns);
 
   // Clone the newly bufferized op.
   SmallVector<Value> newOperands = newInputBuffers;
@@ -1493,7 +1458,7 @@
 /// to allow FuncOp that are inplaceable to write inPlace.
 static LogicalResult bufferize(OpBuilder &b, CallOpInterface callOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFns,
                                DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   FuncOp funcOp = getCalledFunction(callOp);
   assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
@@ -1631,12 +1596,14 @@
 /// tensor::CastOp bufferizes to memref::CastOp.
 static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(castOp);
 
-  Value resultBuffer = getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo);
+  Value resultBuffer =
+      getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn);
   Type sourceType = resultBuffer.getType();
   auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
   auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
@@ -1660,10 +1627,15 @@
 
 static LogicalResult bufferize(OpBuilder &b, arith::ConstantOp constantOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               GlobalCreator &globalCreator) {
+                               BufferizationAliasInfo &aliasInfo) {
   assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
          "not a constant ranked tensor");
+  auto moduleOp = constantOp->getParentOfType<ModuleOp>();
+  if (!moduleOp) {
+    return constantOp.emitError(
+        "cannot bufferize constants not within builtin.module op");
+  }
+  GlobalCreator globalCreator(moduleOp);
 
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -1698,7 +1670,8 @@
 
 static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -1711,7 +1684,8 @@
            "unsupported unranked tensor");
 
     // TODO: More general: Matching bbArg does not bufferize to a read.
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
 
     OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
     BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
@@ -1727,7 +1701,8 @@
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());
@@ -1753,13 +1728,15 @@
 /// TODO: consider hoisting across function boundaries prior to bufferization.
 static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(initTensorOp);
 
   Value alloc = createNewAllocDeallocPairForShapedValue(
-      b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo);
+      b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo,
+      allocationFn);
   map(bvm, initTensorOp.result(), alloc);
   return success();
 }
@@ -1792,7 +1769,8 @@
 /// Bufferization for TiledLoopOp.
 static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -1832,7 +1810,8 @@
     const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
     OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
 
     // Insert mapping and aliasing info.
     aliasInfo.createAliasInfoEntry(resultBuffer);
@@ -1914,7 +1893,8 @@
 /// isolation.
 static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -1934,7 +1914,7 @@
   auto inPlace = getInPlace(extractSliceOp->getResult(0));
   if (inPlace != InPlaceSpec::True)
     alloc = createNewAllocDeallocPairForShapedValue(
-        b, loc, extractSliceOp.result(), aliasInfo);
+        b, loc, extractSliceOp.result(), aliasInfo, allocationFn);
 
   // Set insertion point now that potential alloc/dealloc are introduced.
   b.setInsertionPoint(extractSliceOp);
@@ -1964,7 +1944,8 @@
 
 static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(insertSliceOp);
@@ -1979,8 +1960,8 @@
   // TODO: be very loud about it or even consider failing the pass.
   // Alloc a copy for `insertSliceOp.dest()`, it will become the result
   // buffer.
-  Value dstMemref =
-      getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
+  Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm,
+                                    aliasInfo, allocationFn);
   auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
 
   Value srcMemref = lookup(bvm, insertSliceOp.source());
@@ -2021,7 +2002,8 @@
 
 static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -2042,7 +2024,8 @@
   // Leave the previous transfer_write to dead code as it still has uses at
   // this point.
   auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
-  Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo);
+  Value resultBuffer =
+      getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn);
   b.create<vector::TransferWriteOp>(
       op.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
       writeOp.permutation_map(),
@@ -2268,18 +2251,96 @@
 // Bufferization entry-point for functions.
 //===----------------------------------------------------------------------===//
 
+static Optional<Value> defaultAllocationFnImpl(OpBuilder &b, Location loc,
+                                               Value shapedValue,
+                                               bool useAlloca = false) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+
+  MemRefType allocMemRefType =
+      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
+  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
+    b.setInsertionPointToStart(bbArg.getOwner());
+    loc = bbArg.getOwner()->getParentOp()->getLoc();
+  } else {
+    b.setInsertionPoint(shapedValue.getDefiningOp());
+    loc = shapedValue.getDefiningOp()->getLoc();
+  }
+
+  // Compute the dynamic part of the shape.
+  SmallVector<Value> dynShape;
+  bool foundDynamicShapes = false;
+  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
+          shapedValue.getDefiningOp())) {
+    ReifiedRankedShapedTypeDims resultDims;
+    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
+      foundDynamicShapes = true;
+      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
+      auto &shape = resultDims[resultValue.getResultNumber()];
+      for (auto dim : enumerate(allocMemRefType.getShape()))
+        if (dim.value() == ShapedType::kDynamicSize)
+          dynShape.push_back(shape[dim.index()]);
+    }
+  }
+  if (!foundDynamicShapes) {
+    for (auto dim : enumerate(allocMemRefType.getShape()))
+      if (dim.value() == ShapedType::kDynamicSize)
+        dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
+  }
+
+  // If the buffer is statically shaped, try to hoist it to the first enclosing
+  // parallel region.
+  // TODO: this concept of parallel region and threadlocal needs interfaces.
+  // TODO: also hoist in the dynamic case. For now this relies on subsequent
+  // calls to LICM and buffer hoisting which will most likely not succeed.
+  // TODO: when packing, allocate a static bounding box which will enable more
+  // hoisting.
+  if (dynShape.empty()) {
+    Operation *parent =
+        getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
+                             AffineParallelOp>(shapedValue);
+    if (parent)
+      b.setInsertionPointToStart(&(parent->getRegion(0).front()));
+  }
+  Value allocated;
+  if (useAlloca) {
+    allocated = b.create<memref::AllocaOp>(
+        loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+  } else {
+    allocated = b.create<memref::AllocOp>(
+        loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+  }
+  return allocated;
+}
+
+Optional<Value> mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc,
+                                                  Value shapedValue) {
+  return defaultAllocationFnImpl(b, loc, shapedValue);
+}
+
+void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
+                                         Value allocatedBuffer) {
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(allocatedBuffer.getParentBlock()->getTerminator());
+  b.create<memref::DeallocOp>(loc, allocatedBuffer);
+}
+
 LogicalResult mlir::linalg::bufferizeOp(
     Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes,
-    GlobalCreator *globalCreator) {
+    AllocationCallbacks allocationFns,
+    DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
   OpBuilder b(op->getContext());
   return TypeSwitch<Operation *, LogicalResult>(op)
       // Skip BufferCast and TensorLoad ops.
       .Case<memref::BufferCastOp, memref::TensorLoadOp>(
           [&](auto) { return success(); })
-      .Case<tensor::CastOp, scf::ForOp, InitTensorOp, InsertSliceOp,
-            ExtractSliceOp, LinalgOp, TiledLoopOp, VectorTransferOpInterface,
-            linalg::YieldOp, scf::YieldOp>([&](auto op) {
-        LDBG("Begin bufferize:\n" << op << '\n');
-        return bufferize(b, op, bvm, aliasInfo);
-      })
+      .Case<tensor::CastOp, scf::ForOp, InitTensorOp, InsertSliceOp,
+            ExtractSliceOp, LinalgOp, TiledLoopOp, VectorTransferOpInterface>(
+          [&](auto op) {
+            LDBG("Begin bufferize:\n" << op << '\n');
+            return bufferize(b, op, bvm, aliasInfo, allocationFns);
+          })
+      .Case<linalg::YieldOp, scf::YieldOp>([&](auto op) {
+        LDBG("Begin bufferize:\n" << op << '\n');
+        return bufferize(b, op, bvm, aliasInfo);
+      })
@@ -2289,15 +2350,14 @@
         if (!bufferizedFunctionTypes)
          llvm_unreachable(
              "null bufferizedFunctionTypes when bufferizing CallOpInterface");
-        return bufferize(b, op, bvm, aliasInfo, *bufferizedFunctionTypes);
+        return bufferize(b, op, bvm, aliasInfo, allocationFns,
+                         *bufferizedFunctionTypes);
       })
       .Case([&](arith::ConstantOp op) {
         if (!isaTensor(op.getResult().getType()))
           return success();
         LDBG("Begin bufferize:\n" << op << '\n');
-        if (!globalCreator)
-          llvm_unreachable("null globalCreator when bufferizing ConstantOp");
-        return bufferize(b, op, bvm, aliasInfo, *globalCreator);
+        return bufferize(b, op, bvm, aliasInfo);
       })
       .Default([&](Operation *op) -> LogicalResult {
         auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -2310,14 +2370,13 @@
 static LogicalResult bufferizeFuncOpInternals(
     FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes,
-    GlobalCreator &globalCreator) {
-
+    AllocationCallbacks &allocationFns,
+    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
   /// Start by bufferizing `funcOp` arguments.
-  if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
+  if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns)))
     return failure();
 
   // Walk in PreOrder to ensure ops with regions are handled before their body.
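At this point the entry points are fully rethreaded: `bufferizeOp` takes the callbacks by value and only needs `bufferizedFunctionTypes` for calls, while constant bufferization now locates the enclosing `ModuleOp` itself instead of receiving a pre-built `GlobalCreator`. A hedged sketch of driving `bufferizeOp` directly, mirroring the state that `bufferizeFuncOpInternals` sets up; the `bufferizeSingleOp` wrapper is hypothetical and assumes `funcOp` has already been analyzed for in-place bufferization:

    // Hypothetical driver; mirrors bufferizeFuncOpInternals above. Assumes
    // BufferizationAliasInfo accepts any root op, as with moduleOp elsewhere.
    static LogicalResult bufferizeSingleOp(Operation *op, FuncOp funcOp) {
      BlockAndValueMapping bvm;
      BufferizationAliasInfo aliasInfo(funcOp);
      DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
      AllocationCallbacks callbacks; // Defaults: memref.alloc / memref.dealloc.
      return bufferizeOp(op, bvm, aliasInfo, callbacks,
                         &bufferizedFunctionTypes);
    }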
@@ -2326,8 +2385,8 @@
   SmallVector<Operation *> toErase;
   if (funcOp
           .walk([&](Operation *op) -> WalkResult {
-            if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
-                                   &globalCreator)))
+            if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
+                                   &bufferizedFunctionTypes)))
               return failure();
             // Register post-walk erasure, if necessary.
             if (isa<CallOpInterface>(op))
@@ -2596,12 +2655,22 @@
 struct LinalgComprehensiveModuleBufferize
     : public LinalgComprehensiveModuleBufferizeBase<
           LinalgComprehensiveModuleBufferize> {
+  LinalgComprehensiveModuleBufferize() {}
+  LinalgComprehensiveModuleBufferize(AllocationCallbacks allocationFns)
+      : allocationFns(
+            std::make_unique<AllocationCallbacks>(std::move(allocationFns))) {}
+
+  LinalgComprehensiveModuleBufferize(
+      const LinalgComprehensiveModuleBufferize &p) {}
 
   void runOnOperation() override;
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert();
   }
+
+private:
+  std::unique_ptr<AllocationCallbacks> allocationFns;
 };
 } // end namespace
@@ -2703,6 +2772,19 @@
 }
 
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
+  if (!allocationFns) {
+    if (useAlloca) {
+      AllocationCallbacks allocaAllocationFns = {
+          [](OpBuilder &b, Location loc, Value shapedValue) {
+            return defaultAllocationFnImpl(b, loc, shapedValue, true);
+          },
+          [](OpBuilder &b, Location loc, Value v) {}};
+      allocationFns = std::make_unique<AllocationCallbacks>(
+          std::move(allocaAllocationFns));
+    } else {
+      allocationFns = std::make_unique<AllocationCallbacks>();
+    }
+  }
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
@@ -2712,7 +2794,6 @@
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
     return signalPassFailure();
 
-  GlobalCreator globalCreator(moduleOp);
   DominanceInfo domInfo(moduleOp);
   BufferizationAliasInfo aliasInfo(moduleOp);
   // Interestingly, all function args that are not visible outside of a module
@@ -2745,8 +2826,8 @@
     if (!testAnalysisOnly) {
       BlockAndValueMapping tensorToBufferMap;
       if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
-                                          bufferizedFunctionTypes,
-                                          globalCreator))) {
+                                          *allocationFns,
+                                          bufferizedFunctionTypes))) {
         signalPassFailure();
         return;
       }
@@ -2796,3 +2877,7 @@
 std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
   return std::make_unique<LinalgComprehensiveModuleBufferize>();
 }
+std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass(
+    AllocationCallbacks allocationFns) {
+  return std::make_unique<LinalgComprehensiveModuleBufferize>(allocationFns);
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -pass-pipeline="linalg-comprehensive-module-bufferize{allow-return-memref use-alloca}" -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK: func @init_and_dot(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
+func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
+  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0{{.*}} : f32
+  %v0 = arith.constant 0.0 : f32
+
+  // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32, #[[$DYN_0D_MAP]]>
+  %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+
+  // CHECK-NEXT: linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
+  %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
+    outs(%d: tensor<f32>) -> tensor<f32>
+
+  // CHECK-NEXT: return
+  return %e : tensor<f32>
+}
+
+// CHECK: func @main()
+func @main() {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0{{.*}} : f32
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1{{.*}} : f32
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2{{.*}} : f32
+  %v0 = arith.constant 0.0 : f32
+  %v1 = arith.constant 1.0 : f32
+  %v2 = arith.constant 2.0 : f32
+
+  // CHECK-NEXT: %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref<f32>
+  // CHECK-NEXT: %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
+  // CHECK-NEXT: %[[A:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
+  %A = linalg.init_tensor [64] : tensor<64xf32>
+  %B = linalg.init_tensor [64] : tensor<64xf32>
+  %C = linalg.init_tensor [] : tensor<f32>
+
+  // CHECK-NEXT: linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32>
+  // CHECK-NEXT: linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32>
+  // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32>
+  %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
+  %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
+  %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
+
+  // CHECK-NEXT: %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+  // CHECK-NEXT: %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+  // CHECK-NEXT: %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>
+  // CHECK-NEXT: call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
+  %res = call @init_and_dot(%AA, %BB, %CC) :
+    (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
+
+  // CHECK-NEXT: %[[dC:.*]] = memref.cast %[[C]] : memref<f32> to memref<*xf32>
+  %res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
+
+  // CHECK-NEXT: call @print_memref_f32(%[[dC]]) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
+
+  return
+}
+
+// CHECK: func private @print_memref_f32(memref<*xf32>)
+func private @print_memref_f32(tensor<*xf32>)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6297,6 +6297,7 @@
         ":ComplexDialect",
         ":DialectUtils",
         ":IR",
+        ":InferTypeOpInterface",
        ":LinalgOps",
         ":LinalgPassIncGen",
         ":LinalgStructuredOpsIncGen",
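The new test exercises the use-alloca path through the textual pipeline; the same configuration can be reached programmatically by parsing that pipeline string. A sketch, with error diagnostics elided and the `runBufferize` wrapper assumed:

    // Sketch: building the pipeline from the textual form in the RUN line.
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Pass/PassRegistry.h"

    mlir::LogicalResult runBufferize(mlir::ModuleOp moduleOp,
                                     mlir::MLIRContext &context) {
      mlir::PassManager pm(&context);
      if (failed(mlir::parsePassPipeline(
              "linalg-comprehensive-module-bufferize{allow-return-memref "
              "use-alloca}",
              pm)))
        return mlir::failure();
      return pm.run(moduleOp);
    }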