diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -350,9 +350,16 @@ aliasingOperands.size() == 1 && "ops with multiple aliasing OpOperands cannot bufferize out-of-place"); Location loc = op->getLoc(); + // Move insertion point right after `operandBuffer`. That is where the + // allocation should be inserted (in the absence of allocation hoisting). + if (auto bbArg = operandBuffer.dyn_cast<BlockArgument>()) { + b.setInsertionPointToStart(bbArg.getOwner()); + } else { + b.setInsertionPointAfter(operandBuffer.getDefiningOp()); + } // Allocate the result buffer. Value resultBuffer = - state.allocationFns.createAllocDeallocFn(b, loc, operand, state); + state.allocationFns.createAllocDeallocFn(b, loc, operandBuffer, state); bool skipCopy = false; // Do not copy if the last preceding write of `operand` is an op that does // not write (skipping ops that merely create aliases). E.g., InitTensorOp. @@ -372,7 +379,7 @@ !bufferizesToMemoryRead(*opOperand)) skipCopy = true; if (!skipCopy) { - // Set insertion point now that potential alloc/dealloc are introduced. + // The copy happens right before the op that is bufferized. b.setInsertionPoint(op); state.allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -756,53 +756,68 @@ // Bufferization-specific scoped alloc/dealloc insertion support. 
//===----------------------------------------------------------------------===// -/// Helper function that creates a memref::DimOp or tensor::DimOp depending on -/// the type of `source`. -static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, - int64_t dim) { - if (source.getType().isa<MemRefType>()) - return b.createOrFold<memref::DimOp>(loc, source, dim); - if (source.getType().isa<RankedTensorType>()) - return b.createOrFold<tensor::DimOp>(loc, source, dim); - llvm_unreachable("Expected MemRefType or TensorType"); +/// Move the insertion point of the given builder to the beginning of a +/// surrounding block as much as possible, while not crossing any allocation +/// hoisting barriers. +static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) { + Operation *op = b.getInsertionBlock()->getParentOp(); + while (op) { + if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) + if (bufferizableOp.isAllocationHoistingBarrier()) + break; + op = op->getParentOp(); + } + + // FuncOp is an allocation hoisting barrier, so the above loop should never + // run out of parents. + assert( + (op && cast<BufferizableOpInterface>(op).isAllocationHoistingBarrier()) && + "expected traversal to end at allocation hoisting barrier"); + + // TODO: Handle cases where allocation hoisting barrier has more than one + // region or block. + assert(op->getNumRegions() == 1 && + "allocation hoisting barriers with >1 regions not supported"); + assert(op->getRegion(0).getBlocks().size() == 1 && + "allocation hoisting barriers with >1 blocks not supported"); + b.setInsertionPointToStart(&(op->getRegion(0).front())); } /// Compute the type of the `memref` to use for allocating the buffer for /// `shapedValue`. Also returns (by reference in `dynShape`), the value for the -/// dynamic dimensions in the returned `memref` type. The function also sets the -/// insertion point of the builder `b` to the position where the allocation is -/// to be inserted. +/// dynamic dimensions in the returned `memref` type. 
The function may also set +/// the insertion point to an earlier location, where the allocation should +/// happen ("allocation hoisting"). static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc, Value shapedValue, SmallVectorImpl<Value> &dynShape) { MemRefType allocMemRefType = getContiguousMemRefType(shapedValue.getType().cast<ShapedType>()); - if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) { - b.setInsertionPointToStart(bbArg.getOwner()); - loc = bbArg.getOwner()->getParentOp()->getLoc(); - } else { - b.setInsertionPoint(shapedValue.getDefiningOp()); - loc = shapedValue.getDefiningOp()->getLoc(); - } // Compute the dynamic part of the shape. - bool foundDynamicShapes = false; + bool reifiedShapes = false; if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>( shapedValue.getDefiningOp())) { ReifiedRankedShapedTypeDims resultDims; if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) { - foundDynamicShapes = true; + reifiedShapes = true; OpResult resultValue = shapedValue.dyn_cast<OpResult>(); auto &shape = resultDims[resultValue.getResultNumber()]; for (auto dim : enumerate(allocMemRefType.getShape())) - if (dim.value() == ShapedType::kDynamicSize) + if (ShapedType::isDynamic(dim.value())) dynShape.push_back(shape[dim.index()]); } } - if (!foundDynamicShapes) { + + if (!reifiedShapes) { for (auto dim : enumerate(allocMemRefType.getShape())) - if (dim.value() == ShapedType::kDynamicSize) - dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index())); + if (ShapedType::isDynamic(dim.value())) { + assert((shapedValue.getType().isa<MemRefType>() || + shapedValue.getType().isa<UnrankedMemRefType>()) && + "expected MemRef type"); + dynShape.push_back( + b.create<memref::DimOp>(loc, shapedValue, dim.index())); + } } // If the buffer is statically shaped, try to hoist it to the first enclosing @@ -811,28 +826,9 @@ // calls to LICM and buffer hoisting which will most likely not succeed. // TODO: when packing, allocate a static bounding box which will enable more // hoisting. 
- if (dynShape.empty()) { - Operation *parent; - if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) - parent = bbArg.getOwner()->getParentOp(); - else - parent = shapedValue.getDefiningOp()->getParentOp(); - while (parent) { - if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(parent)) - if (bufferizableOp.isAllocationHoistingBarrier()) - break; - parent = parent->getParentOp(); - } + if (dynShape.empty()) + moveInsertionPointToAllocationHoistingBarrier(b); - // FuncOp is an allocation hoisting barrier, so the above loop should never - // run out of parents. - assert( - (parent && - cast<BufferizableOpInterface>(parent).isAllocationHoistingBarrier()) && - "expected traversal to end at allocation hoisting barrier"); - - b.setInsertionPointToStart(&(parent->getRegion(0).front())); - } return allocMemRefType; } @@ -2247,6 +2243,7 @@ // Take a guard before anything else. OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(extractSliceOp); LDBG("bufferize: " << *extractSliceOp << '\n'); @@ -2263,9 +2260,6 @@ alloc = createNewAllocDeallocPairForShapedValue( b, loc, extractSliceOp.result(), state); - // Set insertion point now that potential alloc/dealloc are introduced. - b.setInsertionPoint(extractSliceOp); - // Bufferize to subview. auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -168,9 +168,9 @@ -> (tensor, tensor, tensor, tensor) { // Hoisted allocs. - // CHECK: %[[REALLOC_A1:.*]] = memref.alloc // CHECK: %[[REALLOC_A0_2:.*]] = memref.alloc // CHECK: %[[REALLOC_A0:.*]] = memref.alloc + // CHECK: %[[REALLOC_A1:.*]] = memref.alloc // Alloc and copy the whole result tensor. Copy the tensor.extract_slice. // CHECK: linalg.copy(%[[A0]], %[[REALLOC_A0]]