diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -304,29 +304,29 @@ /// Determine which OpOperand* will alias with `result` if the op is /// bufferized in place. Return an empty vector if the op is not bufferizable. - SmallVector getAliasingOpOperand(OpResult result); + SmallVector getAliasingOpOperand(OpResult result) const; /// Determine which OpResult will alias with `opOperand` if the op is /// bufferized in place. Return an empty OpResult if the op is not /// bufferizable. - OpResult getAliasingOpResult(OpOperand &opOperand); + OpResult getAliasingOpResult(OpOperand &opOperand) const; /// Return true if `opOperand` bufferizes to a memory read. Return `true` if /// the op is not bufferizable. - bool bufferizesToMemoryRead(OpOperand &opOperand); + bool bufferizesToMemoryRead(OpOperand &opOperand) const; /// Return true if `opOperand` bufferizes to a memory write. Return true` if /// the op is not bufferizable. - bool bufferizesToMemoryWrite(OpOperand &opOperand); + bool bufferizesToMemoryWrite(OpOperand &opOperand) const; /// Return true if `opOperand` does neither read nor write but bufferizes to /// an alias. Return false if the op is not bufferizable. - bool bufferizesToAliasOnly(OpOperand &opOperand); + bool bufferizesToAliasOnly(OpOperand &opOperand) const; /// Return true if the given value is read by an op that bufferizes to a /// memory read. Also takes into account ops that create an alias but do not /// read by themselves (e.g., ExtractSliceOp). - bool isValueRead(Value value); + bool isValueRead(Value value) const; /// Starting from `value`, follow the use-def chain in reverse, always /// selecting the aliasing OpOperands. 
Find and return Values for which @@ -351,9 +351,8 @@ /// In the above example, Values with a star satisfy the condition. When /// starting the traversal from Value 1, the resulting SetVector is: /// { 2, 7, 8, 5 } - llvm::SetVector - findValueInReverseUseDefChain(Value value, - llvm::function_ref condition); + llvm::SetVector findValueInReverseUseDefChain( + Value value, llvm::function_ref condition) const; /// Find the Value of the last preceding write of a given Value. /// @@ -363,33 +362,34 @@ /// /// Note: When reaching an end of the reverse SSA use-def chain, that value /// is returned regardless of whether it is a memory write or not. - Value findLastPrecedingWrite(Value value); + Value findLastPrecedingWrite(Value value) const; /// Creates a memref allocation. Optional createAlloc(OpBuilder &b, Location loc, MemRefType type, - ArrayRef dynShape); + ArrayRef dynShape) const; /// Creates an alloc-dealloc pair. This function may perform additional /// optimizations such as buffer allocation hoisting. Value createAllocDeallocPair(OpBuilder &builder, Location loc, - Value shapedValue); + Value shapedValue) const; /// Creates a memref deallocation. The given memref buffer must have been /// allocated using `createAlloc`. - void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer); + void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer) const; /// Creates a memcpy between two given buffers. - void createMemCpy(OpBuilder &b, Location loc, Value from, Value to); + void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const; /// Replace an op with replacement values. The op is deleted. Tensor OpResults /// must be replaced with memref values. - void replaceOp(RewriterBase &rewriter, Operation *op, ValueRange values); + void replaceOp(RewriterBase &rewriter, Operation *op, + ValueRange values) const; /// Replace an op with a new op. Tensor OpResults must be replaced with memref /// values. 
template OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op, - Args &&...args) { + Args &&...args) const { Operation *newOp = rewriter.create(op->getLoc(), std::forward(args)...); replaceOp(rewriter, op, newOp->getResults()); @@ -398,7 +398,7 @@ /// Lookup the memref buffer that is associated to the given tensor value. /// Asserts if no buffer is associated. - Value lookupBuffer(RewriterBase &rewriter, Value tensor); + Value lookupBuffer(RewriterBase &rewriter, Value tensor) const; /// Return `true` if the given OpResult has been decided to bufferize inplace. bool isInPlace(OpResult opResult) const; @@ -406,10 +406,20 @@ /// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. - Value getResultBuffer(RewriterBase &rewriter, OpResult result); + Value getResultBuffer(RewriterBase &rewriter, OpResult result) const; /// Return dialect-specific bufferization state. - template StateT &getDialectState(StringRef name) { + template + Optional getDialectState(StringRef name) const { + auto it = dialectState.find(name); + if (it == dialectState.end()) + return None; + return static_cast(it->getSecond().get()); + } + + /// Return dialect-specific bufferization state or create one if none exists. + template + StateT &getOrCreateDialectState(StringRef name) { // Create state if it does not exist yet. if (!dialectState.count(name)) dialectState[name] = std::make_unique(); @@ -419,15 +429,10 @@ /// Return a reference to the BufferizationOptions. const BufferizationOptions &getOptions() const { return options; } -private: - friend LogicalResult - runComprehensiveBufferize(Operation *op, const BufferizationOptions &options, - BufferizationState &state); - - friend LogicalResult - runComprehensiveBufferize(ModuleOp moduleOp, - std::unique_ptr options); + /// Return a reference to the BufferizationAliasInfo. 
+ BufferizationAliasInfo &getAliasInfo() { return aliasInfo; } +private: /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal /// functions and `runComprehensiveBufferize` may access this object. BufferizationAliasInfo aliasInfo; @@ -441,17 +446,17 @@ /// Bufferize all ops in the given region. LogicalResult bufferize(RewriterBase &rewriter, Region *region, - BufferizationState &state); + const BufferizationState &state); /// Bufferize all ops in the given block. LogicalResult bufferize(RewriterBase &rewriter, Block *block, - BufferizationState &state); + const BufferizationState &state); /// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this /// function returns immediately. Otherwise, it calls the `bufferize` interface /// method of `BufferizableOpInterface`. LogicalResult bufferize(RewriterBase &rewriter, Operation *op, - BufferizationState &state); + const BufferizationState &state); /// Return a contiguous MemRefType (i.e. with canonical/empty layout map) /// with the same shape as `shapedType` and specified `layout` and @@ -492,38 +497,39 @@ : public BufferizableOpInterface::ExternalModel< AllocationHoistingBarrierOnly, OpTy> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return false; } SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - BufferizationState &state) const { + const BufferizationState &state) const { return {}; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return OpResult(); } BufferRelation bufferRelation(Operation *op, OpResult opResult, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) const { + const 
BufferizationState &state) const { return BufferRelation::None; } - bool isWritable(Operation *op, Value value, BufferizationState &state) const { + bool isWritable(Operation *op, Value value, + const BufferizationState &state) const { return false; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto isaTensor = [](Type t) { return t.isa(); }; if (any_of(op->getOperandTypes(), isaTensor) || any_of(op->getResultTypes(), isaTensor)) diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -33,7 +33,7 @@ /*retType=*/"bool", /*methodName=*/"bufferizesToMemoryRead", /*args=*/(ins "OpOperand &":$opOperand, - "BufferizationState &":$state), + "const BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpOperands. @@ -62,7 +62,7 @@ /*retType=*/"bool", /*methodName=*/"bufferizesToMemoryWrite", /*args=*/(ins "OpOperand &":$opOperand, - "BufferizationState &":$state), + "const BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpOperands. 
@@ -85,7 +85,7 @@ /*retType=*/"bool", /*methodName=*/"isMemoryWrite", /*args=*/(ins "OpResult":$opResult, - "BufferizationState &":$state), + "const BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ auto bufferizableOp = @@ -116,7 +116,7 @@ /*retType=*/"bool", /*methodName=*/"mustBufferizeInPlace", /*args=*/(ins "OpResult":$opResult, - "BufferizationState &":$state), + "const BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ return false; @@ -131,7 +131,7 @@ /*retType=*/"OpResult", /*methodName=*/"getAliasingOpResult", /*args=*/(ins "OpOperand &":$opOperand, - "BufferizationState &":$state), + "const BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpOperands. @@ -155,7 +155,7 @@ /*retType=*/"SmallVector", /*methodName=*/"getAliasingOpOperand", /*args=*/(ins "OpResult":$opResult, - "BufferizationState &":$state), + "const BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opResult.getType().isa() && @@ -188,7 +188,7 @@ /*methodName=*/"bufferRelation", /*args=*/(ins "OpResult":$opResult, "const BufferizationAliasInfo &":$aliasInfo, - "BufferizationState &":$state), + "const BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpResults @@ -210,7 +210,7 @@ /*retType=*/"LogicalResult", /*methodName=*/"bufferize", /*args=*/(ins "RewriterBase &":$rewriter, - "BufferizationState &":$state), + "const BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ llvm_unreachable("bufferize not implemented"); @@ -236,7 +236,7 @@ /*retType=*/"bool", /*methodName=*/"isWritable", /*args=*/(ins "Value":$value, - "BufferizationState &":$state), + "const BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ return value.isa(); @@ -275,7 +275,7 @@ 
/*methodName=*/"isNotConflicting", /*args=*/(ins "OpOperand *":$uRead, "OpOperand *":$uWrite, - "BufferizationState &":$state, + "const BufferizationState &":$state, "const BufferizationAliasInfo &":$aliasInfo), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -292,7 +292,7 @@ /// /// Examples of such ops are `tensor.extract_slice` and `tensor.cast`. bool bufferizesToAliasOnly(OpOperand &opOperand, - BufferizationState &state) { + const BufferizationState &state) { auto bufferizableOp = cast(getOperation()); return !bufferizableOp.bufferizesToMemoryRead(opOperand, state) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp @@ -24,7 +24,7 @@ : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto constantOp = cast(op); assert(constantOp.getType().dyn_cast() && "not a constant ranked tensor"); @@ -40,7 +40,8 @@ return success(); } - bool isWritable(Operation *op, Value value, BufferizationState &state) const { + bool isWritable(Operation *op, Value value, + const BufferizationState &state) const { // Memory locations returned by memref::GetGlobalOp may not be written to. assert(value.isa()); return false; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -199,7 +199,7 @@ /// in place. Return an empty vector if the op is not bufferizable. 
SmallVector mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand( - OpResult result) { + OpResult result) const { if (Operation *op = result.getDefiningOp()) if (auto bufferizableOp = dyn_cast(op)) return bufferizableOp.getAliasingOpOperand(result, *this); @@ -210,7 +210,7 @@ /// in place. Return an empty OpResult if the op is not bufferizable. OpResult mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult( - OpOperand &opOperand) { + OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.getAliasingOpResult(opOperand, *this); @@ -220,7 +220,7 @@ /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the /// op is not bufferizable. bool mlir::linalg::comprehensive_bufferize::BufferizationState:: - bufferizesToMemoryRead(OpOperand &opOperand) { + bufferizesToMemoryRead(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.bufferizesToMemoryRead(opOperand, *this); @@ -233,7 +233,7 @@ /// Return true if `opOperand` bufferizes to a memory write. Return /// `true` if the op is not bufferizable. bool mlir::linalg::comprehensive_bufferize::BufferizationState:: - bufferizesToMemoryWrite(OpOperand &opOperand) { + bufferizesToMemoryWrite(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this); @@ -246,7 +246,7 @@ /// Return true if `opOperand` does neither read nor write but bufferizes to an /// alias. Return false if the op is not bufferizable. bool mlir::linalg::comprehensive_bufferize::BufferizationState:: - bufferizesToAliasOnly(OpOperand &opOperand) { + bufferizesToAliasOnly(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.bufferizesToAliasOnly(opOperand, *this); @@ -260,7 +260,7 @@ /// read. 
Also takes into account ops that create an alias but do not read by /// themselves (e.g., ExtractSliceOp). bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead( - Value value) { + Value value) const { SmallVector workingSet; for (OpOperand &use : value.getUses()) workingSet.push_back(&use); @@ -282,10 +282,9 @@ // the aliasing OpOperands. Find and return Values for which `condition` // evaluates to true. OpOperands of such matching Values are not traversed any // further. -llvm::SetVector -mlir::linalg::comprehensive_bufferize::BufferizationState:: - findValueInReverseUseDefChain(Value value, - llvm::function_ref condition) { +llvm::SetVector mlir::linalg::comprehensive_bufferize:: + BufferizationState::findValueInReverseUseDefChain( + Value value, llvm::function_ref condition) const { llvm::SetVector result, workingSet; workingSet.insert(value); @@ -312,7 +311,7 @@ // Find the Value of the last preceding write of a given Value. Value mlir::linalg::comprehensive_bufferize::BufferizationState:: - findLastPrecedingWrite(Value value) { + findLastPrecedingWrite(Value value) const { SetVector result = findValueInReverseUseDefChain(value, [&](Value value) { Operation *op = value.getDefiningOp(); @@ -360,7 +359,7 @@ /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. 
Value mlir::linalg::comprehensive_bufferize::BufferizationState:: - getResultBuffer(RewriterBase &rewriter, OpResult result) { + getResultBuffer(RewriterBase &rewriter, OpResult result) const { OpBuilder::InsertionGuard guard(rewriter); Operation *op = result.getOwner(); SmallVector aliasingOperands = getAliasingOpOperand(result); @@ -424,7 +423,7 @@ } void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp( - RewriterBase &rewriter, Operation *op, ValueRange values) { + RewriterBase &rewriter, Operation *op, ValueRange values) const { OpBuilder::InsertionGuard g(rewriter); // Replace all OpResults with the given values. @@ -454,7 +453,7 @@ } LogicalResult mlir::linalg::comprehensive_bufferize::bufferize( - RewriterBase &rewriter, Region *region, BufferizationState &state) { + RewriterBase &rewriter, Region *region, const BufferizationState &state) { for (Block &block : *region) if (failed(bufferize(rewriter, &block, state))) return failure(); @@ -462,7 +461,7 @@ } LogicalResult mlir::linalg::comprehensive_bufferize::bufferize( - RewriterBase &rewriter, Block *block, BufferizationState &state) { + RewriterBase &rewriter, Block *block, const BufferizationState &state) { // Ops may get deleted during the traversal, so do not iterate over `block` // directly. SmallVector ops; @@ -476,7 +475,7 @@ } LogicalResult mlir::linalg::comprehensive_bufferize::bufferize( - RewriterBase &rewriter, Operation *op, BufferizationState &state) { + RewriterBase &rewriter, Operation *op, const BufferizationState &state) { // Check if op has tensor results or operands. auto isaTensor = [](Type t) { return t.isa(); }; bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); @@ -592,7 +591,8 @@ /// `shapedValue.getDefiningOp` (or at the top of the block in case of a /// bbArg) and the DeallocOp is at the end of the block. 
Value mlir::linalg::comprehensive_bufferize::BufferizationState:: - createAllocDeallocPair(OpBuilder &b, Location loc, Value shapedValue) { + createAllocDeallocPair(OpBuilder &b, Location loc, + Value shapedValue) const { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -621,19 +621,20 @@ /// Create a memref allocation. Optional mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc( - OpBuilder &b, Location loc, MemRefType type, ArrayRef dynShape) { + OpBuilder &b, Location loc, MemRefType type, + ArrayRef dynShape) const { return options.allocationFns->allocationFn(b, loc, type, dynShape); } /// Create a memref deallocation. void mlir::linalg::comprehensive_bufferize::BufferizationState::createDealloc( - OpBuilder &b, Location loc, Value allocatedBuffer) { + OpBuilder &b, Location loc, Value allocatedBuffer) const { return options.allocationFns->deallocationFn(b, loc, allocatedBuffer); } /// Create a memory copy between two memref buffers. void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy( - OpBuilder &b, Location loc, Value from, Value to) { + OpBuilder &b, Location loc, Value from, Value to) const { return options.allocationFns->memCpyFn(b, loc, from, to); } @@ -649,7 +650,7 @@ } Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer( - RewriterBase &rewriter, Value tensor) { + RewriterBase &rewriter, Value tensor) const { assert(tensor.getType().isa() && "unexpected non-tensor type"); // Replace "%t = to_tensor %m" with %m. 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp @@ -40,18 +40,18 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { // It is unknown whether the resulting MemRef will be read or not. return true; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return OpResult(); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto toMemrefOp = cast(op); // Fold to_memref(to_tensor(x)) to x. @@ -86,11 +86,12 @@ : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { return success(); } - bool isWritable(Operation *op, Value value, BufferizationState &state) const { + bool isWritable(Operation *op, Value value, + const BufferizationState &state) const { // It is unknown whether the MemRef operand is writable or not. 
return false; } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -661,7 +661,7 @@ IRRewriter rewriter(op->getContext()); DominanceInfo domInfo(op); - BufferizationAliasInfo &aliasInfo = state.aliasInfo; + BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo))) return failure(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -24,7 +24,7 @@ /// Generic conversion for any LinalgOp on tensors. static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op, - BufferizationState &state) { + const BufferizationState &state) { // Take a guard before anything else. 
OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); @@ -142,13 +142,13 @@ : public BufferizableOpInterface::ExternalModel, OpTy> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { auto genericOp = cast(op); return genericOp.payloadUsesValueFromOperand(&opOperand); } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { auto bufferizableOp = cast(op); return static_cast( bufferizableOp.getAliasingOpResult(opOperand, state)); @@ -156,7 +156,7 @@ SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - BufferizationState &state) const { + const BufferizationState &state) const { auto genericOp = cast(op); DenseMap pairs = computeAliasingPairs(genericOp); for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) @@ -166,7 +166,7 @@ } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { auto genericOp = cast(op); DenseMap pairs = computeAliasingPairs(genericOp); return pairs[&opOperand]; @@ -174,12 +174,12 @@ BufferRelation bufferRelation(Operation *op, OpResult opResult, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) const { + const BufferizationState &state) const { return BufferRelation::Equivalent; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { return bufferizeLinalgOp(rewriter, cast(op), state); } }; @@ -188,13 +188,13 @@ : public BufferizableOpInterface::ExternalModel { bool isMemoryWrite(Operation *op, OpResult opResult, - BufferizationState &state) const { + const BufferizationState &state) const { // InitTensorOps allocate but do not write. 
return false; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto initTensorOp = cast(op); // The InitTensorOp may have been eliminated. @@ -212,7 +212,7 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { // TiledLoop alone doesn't bufferize to a memory read, one of the uses of // its matching bbArg may. auto tiledLoopOp = cast(op); @@ -220,7 +220,7 @@ } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { // TiledLoop alone doesn't bufferize to a memory write, one of the uses of // its matching bbArg may. auto bufferizableOp = cast(op); @@ -229,18 +229,19 @@ } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { auto tiledLoopOp = cast(op); return tiledLoopOp.getTiedOpResult(opOperand); } BufferRelation bufferRelation(Operation *op, OpResult opResult, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) const { + const BufferizationState &state) const { return BufferRelation::Equivalent; } - bool isWritable(Operation *op, Value value, BufferizationState &state) const { + bool isWritable(Operation *op, Value value, + const BufferizationState &state) const { // Interestingly, linalg::TiledLoopOp's bbArg can **always** be viewed // inplace from the perspective of ops nested under: // 1. 
Either the matching iter operand is not bufferized inplace and an @@ -253,7 +254,7 @@ bool isAllocationHoistingBarrier(Operation *op) const { return true; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto tiledLoopOp = cast(op); // Compute new inputs, outputs and results. @@ -355,22 +356,22 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return false; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return OpResult(); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto yieldOp = cast(op); if (!yieldOp->getParentOfType()) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -34,9 +34,20 @@ }; } // namespace +/// Get ModuleBufferizationState. +static const ModuleBufferizationState & +getModuleBufferizationState(const BufferizationState &state) { + Optional maybeState = + state.getDialectState( + StandardOpsDialect::getDialectNamespace()); + assert(maybeState.hasValue() && "ModuleBufferizationState does not exist"); + return **maybeState; +} + +/// Get or create ModuleBufferizationState. 
static ModuleBufferizationState & getModuleBufferizationState(BufferizationState &state) { - return state.getDialectState( + return state.getOrCreateDialectState( StandardOpsDialect::getDialectNamespace()); } @@ -471,19 +482,25 @@ /// Return the index of the bbArg in the given FuncOp that is equivalent to the /// specified return value (if any). static Optional -getEquivalentFuncArgIdx(FuncOp funcOp, ModuleBufferizationState &state, +getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleBufferizationState &state, int64_t returnValIdx) { - if (!state.equivalentFuncArgs[funcOp].count(returnValIdx)) + if (!state.equivalentFuncArgs.count(funcOp)) + // No equivalence info stored for funcOp. + return None; + + const DenseMap &equivFuncArgs = + state.equivalentFuncArgs.lookup(funcOp); + if (!equivFuncArgs.count(returnValIdx)) // Return value has no equivalent bbArg. return None; - return state.equivalentFuncArgs[funcOp][returnValIdx]; + return equivFuncArgs.lookup(returnValIdx); } struct CallOpInterface : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { // CallOpInterface alone doesn't bufferize to a memory read, one of the uses // of the matching bbArg may. It is the responsibility of the caller to // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be @@ -492,7 +509,7 @@ } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { // CallOpInterface is special, it needs to wait for the callee to be // bufferized and needs to inspect the BufferAliasInfo object. It can't // make a proper determination by itself and needs to be conservative. @@ -503,14 +520,15 @@ /// marked inplaceable. For now, it is the responsibility of the `callOp` /// bufferization to allow FuncOp that are inplaceable to write inPlace. 
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { CallOp callOp = cast(op); unsigned numResults = callOp.getNumResults(); unsigned numOperands = callOp->getNumOperands(); FuncOp funcOp = getCalledFunction(callOp); assert(isa(callOp.getOperation()) && funcOp && "expected CallOp to a FuncOp"); - ModuleBufferizationState &moduleState = getModuleBufferizationState(state); + const ModuleBufferizationState &moduleState = + getModuleBufferizationState(state); // Result types of the bufferized CallOp. SmallVector resultTypes; @@ -626,22 +644,22 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return false; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return OpResult(); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto returnOp = cast(op); assert(isa(returnOp->getParentOp()) && "only support FuncOp parent for ReturnOp"); @@ -662,7 +680,7 @@ struct FuncOpInterface : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto funcOp = cast(op); // Bufferize function body. @@ -670,11 +688,13 @@ } /// Return `true` if the given function argument is writable. 
- bool isWritable(Operation *op, Value value, BufferizationState &state) const { + bool isWritable(Operation *op, Value value, + const BufferizationState &state) const { auto funcOp = cast(op); BlockArgument bbArg = value.dyn_cast(); assert(bbArg && "expected BlockArgument"); - ModuleBufferizationState &moduleState = getModuleBufferizationState(state); + const ModuleBufferizationState &moduleState = + getModuleBufferizationState(state); // In a first approximation: // ========================= @@ -720,8 +740,9 @@ } /// Annotate the IR with the result of the analysis. For testing/debugging only. -static void annotateOpsWithBufferizationMarkers(FuncOp funcOp, - BufferizationState &state) { +static void +annotateOpsWithBufferizationMarkers(FuncOp funcOp, + const BufferizationState &state) { auto bufferizableOp = cast(funcOp.getOperation()); for (BlockArgument bbArg : funcOp.getArguments()) if (bbArg.getType().isa()) @@ -733,7 +754,7 @@ IRRewriter rewriter(moduleOp.getContext()); BufferizationState state(moduleOp, *options); ModuleBufferizationState &moduleState = getModuleBufferizationState(state); - BufferizationAliasInfo &aliasInfo = state.aliasInfo; + BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps, moduleState.callerMap))) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -24,7 +24,7 @@ scf::ExecuteRegionOp> { SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - BufferizationState &state) const { + const BufferizationState &state) const { // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be // any SSA value that is in scope. 
To allow for use-def chain traversal // through ExecuteRegionOps in the analysis, the corresponding yield value @@ -41,7 +41,7 @@ } bool mustBufferizeInPlace(Operation *op, OpResult opResult, - BufferizationState &state) const { + const BufferizationState &state) const { // ExecuteRegionOp results always bufferize in-place. Since they have no // OpOperands, they are mostly ignored by the analysis once alias sets are // set up. @@ -51,7 +51,7 @@ // TODO: For better bufferization results, this could return `true` only if // there is a memory write in the region. bool isMemoryWrite(Operation *op, OpResult opResult, - BufferizationState &state) const { + const BufferizationState &state) const { // Similar to scf.if, results of this op are always considered memory writes // in the analysis. This is a useful pattern for all ops that have tensor // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is @@ -61,7 +61,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { // TODO: Add bufferization support when needed. scf.execute_region should be // bufferized similar to scf.if. auto executeRegionOp = cast(op); @@ -76,7 +76,7 @@ BufferRelation bufferRelation(Operation *op, OpResult opResult, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) const { + const BufferizationState &state) const { return BufferRelation::Equivalent; } }; @@ -85,7 +85,7 @@ : public BufferizableOpInterface::ExternalModel { SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - BufferizationState &state) const { + const BufferizationState &state) const { // IfOps do not have tensor OpOperands. The yielded value can be any SSA // value that is in scope. 
To allow for use-def chain traversal through // IfOps in the analysis, both corresponding yield values from the then/else @@ -102,7 +102,7 @@ // allowed at the moment, we should never encounter scf.ifs that yield // unmodified tensors. Such scf.yield ops could just fold away. bool isMemoryWrite(Operation *op, OpResult opResult, - BufferizationState &state) const { + const BufferizationState &state) const { // IfOp results are always considered memory writes in the analysis. This // design decision simplifies the analysis considerably. E.g., consider the // following test case: @@ -129,14 +129,14 @@ } bool mustBufferizeInPlace(Operation *op, OpResult opResult, - BufferizationState &state) const { + const BufferizationState &state) const { // IfOp results always bufferize in-place. Since they have no OpOperands, // they are mostly ignored by the analysis once alias sets are set up. return true; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto ifOp = cast(op); // Compute new types of the bufferized scf.if op. @@ -209,7 +209,7 @@ BufferRelation bufferRelation(Operation *op, OpResult opResult, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) const { + const BufferizationState &state) const { // IfOp results are equivalent to their corresponding yield values if both // yield values are equivalent to each other. auto bufferizableOp = cast(op); @@ -226,7 +226,7 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of // its matching bbArg may. 
auto forOp = cast(op); @@ -234,7 +234,7 @@ } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { // Tensor iter_args of scf::ForOps are always considered as a write. This is // to simplify the analysis. // TODO: Consider doing sth. like isValueWritten. @@ -242,7 +242,7 @@ } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { auto forOp = cast(op); if (!opOperand.get().getType().isa()) return OpResult(); @@ -251,7 +251,7 @@ BufferRelation bufferRelation(Operation *op, OpResult opResult, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) const { + const BufferizationState &state) const { // ForOp results are equivalent to their corresponding init_args if the // corresponding iter_args and yield values are equivalent. auto forOp = cast(op); @@ -263,7 +263,8 @@ return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None; } - bool isWritable(Operation *op, Value value, BufferizationState &state) const { + bool isWritable(Operation *op, Value value, + const BufferizationState &state) const { // Interestingly, scf::ForOp's bbArg can **always** be viewed // inplace from the perspective of ops nested under: // 1. 
Either the matching iter operand is not bufferized inplace and an @@ -274,7 +275,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto forOp = cast(op); Block *oldLoopBody = &forOp.getLoopBody().front(); @@ -416,22 +417,22 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return false; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return OpResult(); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto yieldOp = cast(op); if (!isa( yieldOp->getParentOp())) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -27,28 +27,28 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return false; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return false; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return op->getResult(0); } BufferRelation bufferRelation(Operation *op, OpResult 
opResult, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) const { + const BufferizationState &state) const { return BufferRelation::Equivalent; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto castOp = cast(op); Value resultBuffer = state.getResultBuffer(rewriter, castOp->getResult(0)); @@ -78,22 +78,22 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return false; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return OpResult(); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto dimOp = cast(op); if (!dimOp.source().getType().isa()) return dimOp.emitError("unranked tensor not supported"); @@ -107,17 +107,17 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return false; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return false; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return &opOperand == &op->getOpOperand(0) /*source*/ ? 
op->getResult(0) : OpResult(); @@ -125,12 +125,12 @@ BufferRelation bufferRelation(Operation *op, OpResult opResult, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) const { + const BufferizationState &state) const { return BufferRelation::None; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto extractSliceOp = cast(op); Location loc = extractSliceOp.getLoc(); Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source()); @@ -173,22 +173,22 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return false; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return OpResult(); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto extractOp = cast(op); Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor()); state.replaceOpWithNewOp(rewriter, op, srcMemref, @@ -201,17 +201,17 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return true; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { assert(&opOperand == &op->getOpOperand(1) /*dest*/ && 
"expected dest OpOperand"); return op->getOpResult(0); @@ -219,12 +219,12 @@ SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - BufferizationState &state) const { + const BufferizationState &state) const { return {&op->getOpOperand(1) /*dest*/}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto insertOp = cast(op); Location loc = insertOp.getLoc(); Value destMemref = @@ -237,7 +237,7 @@ BufferRelation bufferRelation(Operation *op, OpResult opResult, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) const { + const BufferizationState &state) const { return BufferRelation::Equivalent; } }; @@ -263,8 +263,8 @@ /// Return true if `value` is originating from an ExtractSliceOp that matches /// the given InsertSliceOp. static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo, - BufferizationState &state, Value value, - InsertSliceOp insertOp) { + const BufferizationState &state, + Value value, InsertSliceOp insertOp) { auto condition = [&](Value val) { if (auto extractOp = val.getDefiningOp()) if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp)) @@ -280,17 +280,17 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return &opOperand == &op->getOpOperand(1) /*dest*/; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return &opOperand == &op->getOpOperand(1) /*dest*/ ? 
op->getResult(0) : OpResult(); @@ -298,12 +298,13 @@ BufferRelation bufferRelation(Operation *op, OpResult opResult, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) const { + const BufferizationState &state) const { return BufferRelation::Equivalent; } bool isNotConflicting(Operation *op, OpOperand *uRead, - OpOperand *uConflictingWrite, BufferizationState &state, + OpOperand *uConflictingWrite, + const BufferizationState &state, const BufferizationAliasInfo &aliasInfo) const { Operation *readingOp = uRead->getOwner(); Operation *conflictingWritingOp = uConflictingWrite->getOwner(); @@ -380,7 +381,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { // insert_slice ops arise from tiling and bufferizing them out-of-place is // generally a deal breaker. When used with loops, this ends up cloning the // whole tensor on every single iteration and is a symptom of a diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -21,26 +21,26 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return false; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { return 
OpResult(); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto readOp = cast(op); assert(readOp.getShapedType().isa() && "only tensor types expected"); @@ -60,21 +60,21 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return true; } OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - BufferizationState &state) const { + const BufferizationState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return op->getOpResult(0); @@ -82,12 +82,12 @@ BufferRelation bufferRelation(Operation *op, OpResult opResult, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) const { + const BufferizationState &state) const { return BufferRelation::Equivalent; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationState &state) const { auto writeOp = cast(op); assert(writeOp.getShapedType().isa() && "only tensor types expected");