diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -29,7 +29,10 @@
 // TODO: from some HW description.
 static constexpr int64_t kBufferAlignments = 128;
 
-struct BufferizationState;
+class BufferizationAliasInfo;
+struct BufferizationOptions;
+class BufferizationState;
+struct PostAnalysisStep;
 
 /// Callback functions that are used to allocate/deallocate/copy memory buffers.
 /// Comprehensive Bufferize provides default implementations of these functions.
@@ -68,6 +71,7 @@
   /// `aliasInfo` (inside `state`) consistent. Newly created operations and
   /// operations that should be re-analyzed must be stored in `newOps`.
   virtual LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                            BufferizationAliasInfo &aliasInfo,
                             SmallVector<Operation *> &newOps) = 0;
 };
 
@@ -281,9 +285,20 @@
   virtual ~DialectBufferizationState() = default;
 };
 
-/// BufferizationState keeps track of bufferization state and provides access to
-/// the results of the analysis.
-struct BufferizationState {
+/// BufferizationState keeps track of memory buffers and provides a variety of
+/// helper functions for dealing with them. In particular,
+/// `BufferizableOpInterface::bufferize` implementations should use the
+/// following helper functions.
+///
+/// * `createAlloc` / `createDealloc` / `createAllocDeallocPair` create ops
+///   that allocate and/or deallocate memref buffers.
+/// * `mapBuffer` maps a tensor value to a memref buffer during bufferization.
+/// * `lookupBuffer` returns the mapped memref buffer of a given tensor value.
+/// * `getResultBuffer` returns the memref buffer for a given tensor OpResult.
+///   Based on the inplace bufferization decisions of the analysis, it may
+///   either directly return a mapped buffer or allocate a brand new buffer.
+class BufferizationState {
+public:
   BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options)
       : aliasInfo(moduleOp), options(options), builder(moduleOp->getContext()) {}
 
@@ -291,11 +306,21 @@
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;
 
-  /// A function that creates an alloc-dealloc pair. This function may perform
-  /// additional optimizations such as buffer allocation hoisting. This function
-  /// calls `allocationFn` and `deallocationFn` to create (de)allocations.
-  Value createAllocDeallocFn(OpBuilder &builder, Location loc,
-                             Value shapedValue);
+  /// Creates a memref allocation.
+  Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+                              ArrayRef<Value> dynShape);
+
+  /// Creates an alloc-dealloc pair. This function may perform additional
+  /// optimizations such as buffer allocation hoisting.
+  Value createAllocDeallocPair(OpBuilder &builder, Location loc,
+                               Value shapedValue);
+
+  /// Creates a memref deallocation. The given memref buffer must have been
+  /// allocated using `createAlloc`.
+  void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer);
+
+  /// Creates a memcpy between the two given buffers.
+  void createMemCpy(OpBuilder &b, Location loc, Value from, Value to);
 
   /// Map tensor values to memref buffers.
   void mapBuffer(ValueRange tensors, ValueRange buffers);
 
@@ -307,6 +332,9 @@
   /// Asserts if no buffer is associated.
   Value lookupBuffer(Value tensor);
 
+  /// Return `true` if the given OpResult has been decided to bufferize inplace.
+  bool isInPlace(OpResult opResult) const;
+
   /// Return `true` if the given value is mapped.
   bool isMapped(Value value) const;
 
@@ -329,7 +357,24 @@
     return static_cast<StateT &>(*dialectState[name]);
   }
 
-  /// `aliasInfo` keeps track of aliasing and equivalent values.
+  /// Return a reference to the BufferizationOptions.
+  const BufferizationOptions &getOptions() const { return options; }
+
+  /// Return a reference to the OpBuilder.
+  OpBuilder &getBuilder() { return builder; }
+
+private:
+  friend LogicalResult
+  runComprehensiveBufferize(FuncOp funcOp, const BufferizationOptions &options,
+                            BufferizationState &state,
+                            const PostAnalysisStepList &extraSteps);
+
+  friend LogicalResult
+  runComprehensiveBufferize(ModuleOp moduleOp,
+                            const BufferizationOptions &options);
+
+  /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
+  /// functions and `runComprehensiveBufferize` may access this object.
   BufferizationAliasInfo aliasInfo;
 
   /// The mapping of tensors to buffers.
@@ -428,7 +473,7 @@
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
   if (any_of(op->getOperandTypes(), isaTensor) ||
       any_of(op->getResultTypes(), isaTensor))
-    if (!state.options.allowUnknownOps)
+    if (!state.getOptions().allowUnknownOps)
       return op->emitError() << "unsupported op with tensors";
 
   for (Region &region : op->getRegions())
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -35,6 +35,7 @@
   /// This analysis can be skipped with `skipAnalysis`.
   LogicalResult eliminateInitTensors(
       FuncOp funcOp, BufferizationState &state,
+      BufferizationAliasInfo &aliasInfo,
      std::function<bool(OpOperand &)> anchorMatchFunc,
      std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
      SmallVector<Operation *> &newOps);
 
@@ -46,6 +47,7 @@
 struct InsertSliceAnchoredInitTensorEliminationStep
     : public InitTensorEliminationStep {
   LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    BufferizationAliasInfo &aliasInfo,
                     SmallVector<Operation *> &newOps) override;
 };
 
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
@@ -23,6 +23,7 @@
 /// equivalent to their corresponding loop yield values.
 struct AssertDestinationPassingStyle : public PostAnalysisStep {
   LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    BufferizationAliasInfo &aliasInfo,
                     SmallVector<Operation *> &newOps) override;
 };
 
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
@@ -21,6 +21,7 @@
 struct InplaceInsertSliceOpAnalysis : public PostAnalysisStep {
   LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    BufferizationAliasInfo &aliasInfo,
                     SmallVector<Operation *> &newOps) override;
 };
 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -367,7 +367,7 @@
   // allocation should be inserted (in the absence of allocation hoisting).
   setInsertionPointAfter(builder, operandBuffer);
   // Allocate the result buffer.
-  Value resultBuffer = createAllocDeallocFn(builder, loc, operandBuffer);
+  Value resultBuffer = createAllocDeallocPair(builder, loc, operandBuffer);
   bool skipCopy = false;
   // Do not copy if the last preceding write of `operand` is an op that does
   // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -389,8 +389,7 @@
   if (!skipCopy) {
     // The copy happens right before the op that is bufferized.
     builder.setInsertionPoint(op);
-    options.allocationFns->memCpyFn(builder, loc, operandBuffer,
-                                    resultBuffer);
+    createMemCpy(builder, loc, operandBuffer, resultBuffer);
   }
   return resultBuffer;
 }
@@ -420,7 +419,7 @@
 LogicalResult
 mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
                                                  BufferizationState &state) {
-  OpBuilder &b = state.builder;
+  OpBuilder &b = state.getBuilder();
 
   // Check if op has tensor results or operands.
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -443,7 +442,7 @@
   }
 
   // `op` is an unbufferizable tensor op.
-  if (!state.options.allowUnknownOps)
+  if (!state.getOptions().allowUnknownOps)
     return op->emitError() << "unsupported op with tensors";
 
   // Replace all OpOperands with "to-tensor casted" bufferized values.
@@ -550,7 +549,7 @@
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
 Value mlir::linalg::comprehensive_bufferize::BufferizationState::
-    createAllocDeallocFn(OpBuilder &b, Location loc, Value shapedValue) {
+    createAllocDeallocPair(OpBuilder &b, Location loc, Value shapedValue) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -561,8 +560,7 @@
   // Note: getAllocationTypeAndShape also sets the insertion point.
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
-  Optional<Value> allocated =
-      options.allocationFns->allocationFn(b, loc, allocMemRefType, dynShape);
+  Optional<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
   // TODO: For now just assert the value is returned. Eventually need to
   // error-propagate.
   assert(allocated && "allocation failed");
@@ -573,10 +571,29 @@
   // 2. Create memory deallocation.
   b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
-  options.allocationFns->deallocationFn(b, loc, allocated.getValue());
+  createDealloc(b, loc, allocated.getValue());
   return casted;
 }
 
+/// Create a memref allocation.
+Optional<Value>
+mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
+    OpBuilder &b, Location loc, MemRefType type, ArrayRef<Value> dynShape) {
+  return options.allocationFns->allocationFn(b, loc, type, dynShape);
+}
+
+/// Create a memref deallocation.
+void mlir::linalg::comprehensive_bufferize::BufferizationState::createDealloc(
+    OpBuilder &b, Location loc, Value allocatedBuffer) {
+  return options.allocationFns->deallocationFn(b, loc, allocatedBuffer);
+}
+
+/// Create a memory copy between two memref buffers.
+void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
+    OpBuilder &b, Location loc, Value from, Value to) {
+  return options.allocationFns->memCpyFn(b, loc, from, to);
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
@@ -648,9 +665,15 @@
 bool mlir::linalg::comprehensive_bufferize::BufferizationState::isMapped(
     Value value) const {
+  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
   return mapping.contains(value);
 }
 
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
+    OpResult opResult) const {
+  return aliasInfo.isInPlace(opResult);
+}
+
 void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete(
     Operation *op) {
   obsoleteOps.push_back(op);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -732,7 +732,7 @@
   auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
     for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
       SmallVector<Operation *> newOps;
-      if (failed(step->run(funcOp, state, newOps)))
+      if (failed(step->run(funcOp, state, aliasInfo, newOps)))
        return failure();
      // Analyze ops that were created by the PostAnalysisStep.
      if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -159,8 +159,8 @@
   if (initTensorOp->getUses().empty())
     return success();
 
-  Value alloc = state.createAllocDeallocFn(b, initTensorOp->getLoc(),
-                                           initTensorOp.result());
+  Value alloc = state.createAllocDeallocPair(b, initTensorOp->getLoc(),
+                                             initTensorOp.result());
   state.mapBuffer(initTensorOp.result(), alloc);
   return success();
 }
@@ -379,11 +379,11 @@
 LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
     InitTensorEliminationStep::eliminateInitTensors(
         FuncOp funcOp, BufferizationState &state,
+        BufferizationAliasInfo &aliasInfo,
        std::function<bool(OpOperand &)> anchorMatchFunc,
        std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
        SmallVector<Operation *> &newOps) {
   OpBuilder b(funcOp->getContext());
-  BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   WalkResult status = funcOp->walk([&](Operation *op) {
     for (OpOperand &operand : op->getOpOperands()) {
@@ -474,16 +474,16 @@
 LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
     InsertSliceAnchoredInitTensorEliminationStep::run(
         FuncOp funcOp, BufferizationState &state,
-        SmallVector<Operation *> &newOps) {
+        BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
   return eliminateInitTensors(
-      funcOp, state,
+      funcOp, state, aliasInfo,
      [&](OpOperand &operand) {
        auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
        if (!insertSliceOp)
          return false;
        // Only inplace bufferized InsertSliceOps are eligible.
-        if (!state.aliasInfo.isInPlace(insertSliceOp->getOpResult(0)))
+        if (!aliasInfo.isInPlace(insertSliceOp->getOpResult(0)))
          return false;
        return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
      },
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -88,6 +88,7 @@
   }
 
   LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    BufferizationAliasInfo &aliasInfo,
                     SmallVector<Operation *> &newOps) override {
     ModuleBufferizationState &moduleState =
         getModuleBufferizationState(state);
@@ -99,12 +100,12 @@
       if (returnVal.get().getType().isa<TensorType>())
        for (BlockArgument bbArg : funcOp.getArguments())
          if (bbArg.getType().isa<TensorType>())
-            if (state.aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
-                                                              bbArg)) {
+            if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
+                                                        bbArg)) {
              moduleState
                  .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
                  bbArg.getArgNumber();
-              if (state.options.testAnalysisOnly)
+              if (state.getOptions().testAnalysisOnly)
                annotateReturnOp(returnVal, bbArg);
            }
 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -265,6 +265,7 @@
 LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
     AssertDestinationPassingStyle::run(FuncOp funcOp, BufferizationState &state,
+                                       BufferizationAliasInfo &aliasInfo,
                                        SmallVector<Operation *> &newOps) {
   LogicalResult status = success();
   funcOp->walk([&](scf::YieldOp yieldOp) {
@@ -280,8 +281,7 @@
      OpOperand &forOperand = forOp.getOpOperandForResult(
          forOp->getResult(operand.getOperandNumber()));
      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
-                                                         bbArg)) {
+      if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
        // TODO: this could get resolved with copies but it can also turn into
        // swaps so we need to be careful about order of copies.
        status =
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -144,10 +144,10 @@
        extractSliceOp.result().getType().cast<RankedTensorType>();
 
    // If not inplaceable, alloc.
-    bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
+    bool inplace = state.isInPlace(extractSliceOp->getResult(0));
    Value alloc;
    if (!inplace)
-      alloc = state.createAllocDeallocFn(b, loc, extractSliceOp.result());
+      alloc = state.createAllocDeallocPair(b, loc, extractSliceOp.result());
 
    // Bufferize to subview.
    auto subviewMemRefType =
@@ -159,15 +159,12 @@
    Value subView = b.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
-    // Insert new alias.
-    state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
 
    /// If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (isValueRead(extractSliceOp.result()))
-        state.options.allocationFns->memCpyFn(b, extractSliceOp.getLoc(),
-                                              subView, alloc);
+        state.createMemCpy(b, extractSliceOp.getLoc(), subView, alloc);
      subView = alloc;
    }
 
@@ -421,8 +418,7 @@
        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
    // Copy tensor.
    Value srcMemref = state.lookupBuffer(insertSliceOp.source());
-    state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(),
-                                          srcMemref, subView);
+    state.createMemCpy(b, insertSliceOp.getLoc(), srcMemref, subView);
  }
 
  state.mapBuffer(insertSliceOp.result(), dstMemref);
@@ -437,6 +433,7 @@
 LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext::
    InplaceInsertSliceOpAnalysis::run(FuncOp funcOp, BufferizationState &state,
+                                      BufferizationAliasInfo &aliasInfo,
                                      SmallVector<Operation *> &newOps) {
  auto &tensorState = getTensorBufferizationState(state);
  funcOp.walk([&](InsertSliceOp insertSliceOp) {
@@ -445,9 +442,9 @@
    //   slice is computed out of place into the inplace full tensor.
    // - The result is not inplace. This is the case where the whole tensor is
    //   cloned and the clone needs to be updated.
-    if (isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo,
+    if (isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo,
                                                            insertSliceOp) &&
-        state.aliasInfo.isInPlace(insertSliceOp->getResult(0)))
+        state.isInPlace(insertSliceOp->getResult(0)))
      tensorState.insertSliceOpsWithoutCopy.insert(insertSliceOp);
  });
  return success();
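
Usage note (not part of the patch): the sketch below illustrates how a bufferization routine might use the new BufferizationState helpers introduced above. The helper signatures (getBuilder, lookupBuffer, isInPlace, createAllocDeallocPair, createMemCpy, mapBuffer) are taken from this patch; the unary-op shape of the code and the assumption that the result buffer matches the operand's shape are hypothetical and for illustration only.

// Illustrative sketch only. Assumes an op with one tensor operand and one
// tensor result of the same shape; not an actual BufferizableOpInterface
// implementation from the tree.
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"

using namespace mlir;
using namespace mlir::linalg::comprehensive_bufferize;

static LogicalResult bufferizeUnaryTensorOp(Operation *op,
                                            BufferizationState &state) {
  OpBuilder &b = state.getBuilder();
  b.setInsertionPoint(op);
  Location loc = op->getLoc();

  // The operand has already been bufferized; look up its memref buffer.
  Value srcBuffer = state.lookupBuffer(op->getOperand(0));

  // Reuse the operand's buffer if the analysis decided to bufferize the
  // result in place; otherwise allocate (and schedule deallocation of) a
  // fresh buffer and copy the operand's contents into it.
  Value resultBuffer = srcBuffer;
  if (!state.isInPlace(op->getOpResult(0))) {
    resultBuffer = state.createAllocDeallocPair(b, loc, op->getOperand(0));
    state.createMemCpy(b, loc, srcBuffer, resultBuffer);
  }

  // ... emit the memref-based computation that writes into resultBuffer ...

  // Record the tensor-to-buffer mapping so users of the result find it.
  state.mapBuffer(op->getOpResult(0), resultBuffer);
  return success();
}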