diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -166,4 +166,57 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h.inc" +namespace mlir { +namespace linalg { +namespace comprehensive_bufferize { + +/// AllocationHoistingBarrierOnly is an external implementation of +/// BufferizableOpInterface for ops that are (not yet) bufferizable, but are +/// known to be allocation hoisting barriers. All interface methods (except for +/// `isAllocationHoistingBarrier`) are implemented conservatively. +template +struct AllocationHoistingBarrierOnly + : public BufferizableOpInterface::ExternalModel< + AllocationHoistingBarrierOnly, OpTy> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + return false; + } + + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + return {}; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + return OpResult(); + } + + BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + return BufferRelation::None; + } + + bool isWritable(Operation *op, Value value) const { return false; } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, + AllocationCallbacks &allocationFn) const { + auto isaTensor = [](Type t) { return t.isa(); }; + if (any_of(op->getOperandTypes(), isaTensor) || + any_of(op->getResultTypes(), isaTensor)) + return op->emitError() << "unsupported op with tensors"; + return success(); + } + + bool isAllocationHoistingBarrier(Operation *op)
const { return true; } +}; + +} // namespace comprehensive_bufferize +} // namespace linalg +} // namespace mlir + #endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -203,6 +203,22 @@ /*defaultImplementation=*/[{ return value.isa(); }] + >, + InterfaceMethod< + /*desc=*/[{ + Return `true` if the op is an allocation hoisting barrier. Buffer + allocations will never be hoisted beyond such ops. E.g., ops with + certain parallel semantics may be allocation hoisting barriers. Most + ops, however, are not barriers. Therefore, this method returns + `false` by default. + }], + /*retType=*/"bool", + /*methodName=*/"isAllocationHoistingBarrier", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return false; + }] > ]; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -843,21 +843,6 @@ // Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===// -template -Operation *getFirstParentOfType(Value v) { - Operation *parent; - if (auto bbArg = v.dyn_cast()) - parent = bbArg.getOwner()->getParentOp(); - else - parent = v.getDefiningOp()->getParentOp(); - while (parent) { - if (isa(parent)) - return parent; - parent = parent->getParentOp(); - } - return nullptr; -} - /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, @@ -909,17 +894,31 @@ // If the buffer is statically shaped, try to hoist it to the first enclosing // parallel region. - // TODO: this concept of parallel region and threadlocal needs interfaces. // TODO: also hoist in the dynamic case. For now this relies on subsequent // calls to LICM and buffer hoisting which will most likely not succeed. // TODO: when packing, allocate a static bounding box which will enable more // hoisting. if (dynShape.empty()) { - Operation *parent = - getFirstParentOfType(shapedValue); - if (parent) - b.setInsertionPointToStart(&(parent->getRegion(0).front())); + Operation *parent; + if (auto bbArg = shapedValue.dyn_cast()) + parent = bbArg.getOwner()->getParentOp(); + else + parent = shapedValue.getDefiningOp()->getParentOp(); + while (parent) { + if (auto bufferizableOp = dyn_cast(parent)) + if (bufferizableOp.isAllocationHoistingBarrier()) + break; + parent = parent->getParentOp(); + } + + // FuncOp is an allocation hoisting barrier, so the above loop should never + // run out of parents. If it ever does (asserts disabled), do not hoist.
+ assert( + (parent && + cast(parent).isAllocationHoistingBarrier()) && + "expected traversal to end at allocation hoisting barrier"); + if (parent) + b.setInsertionPointToStart(&(parent->getRegion(0).front())); } return allocMemRefType; } @@ -2239,6 +2238,8 @@ return true; } + bool isAllocationHoistingBarrier(Operation *op) const { return true; } + LogicalResult bufferize(Operation *op, OpBuilder &b, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo, @@ -3223,6 +3224,13 @@ registry.addOpInterface(); + // Ops that are not bufferizable but are allocation hoisting barriers. + registry.addOpInterface>(); + registry.addOpInterface>(); + registry.addOpInterface>(); + // Register all Linalg structured ops. `LinalgOp` is an interface and it is // not possible to attach an external interface to an existing interface. // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one. diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -33,7 +33,7 @@ registry .insert(); + arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>(); registerBufferizableOpInterfaceExternalModels(registry); } };