diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
 #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
 
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
@@ -40,6 +41,9 @@
 public:
   explicit BufferizationAliasInfo(Operation *rootOp);
 
+  // BufferizationAliasInfo should be passed as a reference.
+  BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
+
   /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
   /// beginning the alias and equivalence sets only contain `v` itself.
   void createAliasInfoEntry(Value v);
@@ -237,14 +241,18 @@
 /// the results of the analysis.
 struct BufferizationState {
   BufferizationState(BufferizationAliasInfo &aliasInfo,
-                     AllocationCallbacks &allocationFns,
-                     BlockAndValueMapping &tensorToBufferMap)
-      : aliasInfo(aliasInfo), allocationFns(allocationFns),
-        tensorToBufferMap(tensorToBufferMap) {}
+                     AllocationCallbacks &allocationFns)
+      : aliasInfo(aliasInfo), allocationFns(allocationFns) {}
+
+  // BufferizationState should be passed as a reference.
+  BufferizationState(const BufferizationState &) = delete;
 
   /// Map tensor values to memref buffers.
   void mapBuffer(ValueRange tensors, ValueRange buffers);
 
+  /// Map a value to another value.
+  void mapValue(Value from, Value to);
+
   /// Map a tensor value to a memref buffer.
   void mapBuffer(Value tensor, Value buffer);
 
@@ -252,6 +260,16 @@
   /// Asserts if no buffer is associated.
   Value lookupBuffer(Value tensor) const;
 
+  /// Lookup the value that is associated to the given value. Asserts if no
+  /// value is associated.
+  Value lookupValue(Value value) const;
+
+  /// Return `true` if the given value is mapped.
+  bool isMapped(Value value) const;
+
+  /// Mark `op` as obsolete, so that it is deleted after bufferization.
+  void markOpObsolete(Operation *op);
+
   /// `aliasInfo` keeps track of aliasing and equivalent values.
   BufferizationAliasInfo &aliasInfo;
 
@@ -259,8 +277,12 @@
   /// ops and memcpy ops.
   AllocationCallbacks &allocationFns;
 
-  /// The mapping of tensors to buffers.
-  BlockAndValueMapping &tensorToBufferMap;
+  /// The mapping of tensors to buffers. May also contain mappings of non-tensor
+  /// values.
+  BlockAndValueMapping mapping;
+
+  /// Obsolete ops that should be deleted after bufferization.
+  SmallVector<Operation *> obsoleteOps;
 };
 
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -398,7 +398,19 @@
 void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
     ValueRange tensors, ValueRange buffers) {
   assert(!tensors.empty() && "unexpected empty tensors");
-  return tensorToBufferMap.map(tensors, buffers);
+#ifndef NDEBUG
+  for (Value tensor : tensors) {
+    assert(tensor && "unexpected empty tensor");
+    assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
+  }
+  for (Value buffer : buffers) {
+    assert(buffer && "unexpected empty buffer");
+    assert((buffer.getType().isa<MemRefType>() ||
+            buffer.getType().isa<UnrankedMemRefType>()) &&
+           "expected that tensor is mapped to memref");
+  }
+#endif // NDEBUG
+  return mapping.map(tensors, buffers);
 }
 
 /// Wrapper for better debugging.
@@ -406,7 +418,17 @@
     Value tensor, Value buffer) {
   assert(tensor && "unexpected empty tensor");
   assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
-  return tensorToBufferMap.map(tensor, buffer);
+  assert(buffer && "unexpected empty buffer");
+  assert((buffer.getType().isa<MemRefType>() ||
+          buffer.getType().isa<UnrankedMemRefType>()) &&
+         "expected that tensor is mapped to memref");
+  return mapping.map(tensor, buffer);
+}
+
+void mlir::linalg::comprehensive_bufferize::BufferizationState::mapValue(
+    Value from, Value to) {
+  assert(from && "unexpected empty value");
+  return mapping.map(from, to);
 }
 
 /// Wrapper for better debugging.
@@ -414,7 +436,7 @@
     Value tensor) const {
   // TODO: if key comes from bbArg, forward.
   assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
-  Value v = tensorToBufferMap.lookupOrNull(tensor);
+  Value v = mapping.lookupOrNull(tensor);
 
   if (!v) {
     // Dump tensor for easier debugging.
@@ -423,5 +445,28 @@
     return Value();
   }
 
+  assert((v.getType().isa<MemRefType>() ||
+          v.getType().isa<UnrankedMemRefType>()) &&
+         "expected that tensor is mapped to memref");
   return v;
 }
+
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupValue(
+    Value value) const {
+  Value v = mapping.lookupOrNull(value);
+  if (!v) {
+    llvm_unreachable("value is not mapped");
+    return Value();
+  }
+  return v;
+}
+
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::isMapped(
+    Value value) const {
+  return mapping.contains(value);
+}
+
+void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete(
+    Operation *op) {
+  obsoleteOps.push_back(op);
+}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -1677,12 +1677,16 @@
 
     // Bufferization phase.
     if (!options.testAnalysisOnly) {
-      BlockAndValueMapping tensorToBufferMap;
-      BufferizationState state(aliasInfo, *options.allocationFns,
-                               tensorToBufferMap);
+      BufferizationState state(aliasInfo, *options.allocationFns);
+
+      // Bufferize all ops in funcOp.
       if (failed(
               bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes)))
         return failure();
+
+      // Erase all obsolete ops.
+      for (Operation *op : state.obsoleteOps)
+        op->erase();
     }
   }
   // Annotate operations if we only want to report the analysis.