diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -1463,6 +1463,7 @@ } b.setInsertionPoint(allocated.getParentBlock()->getTerminator()); b.create(loc, allocated); + return casted; } @@ -1484,7 +1485,7 @@ BufferizationAliasInfo &aliasInfo) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); - b.setInsertionPointAfter(op); + b.setInsertionPoint(op); // TODO: provide the proper interface to iterate on OpResults and get the // matching OpOperands. @@ -1532,7 +1533,6 @@ if (!op.hasTensorSemantics()) return op->emitError() << "op does not have tensor semantics"; - b.setInsertionPoint(op); Location loc = op.getLoc(); SmallVector newInputBuffers; newInputBuffers.reserve(op.getNumInputs()); @@ -1551,6 +1551,9 @@ // Clone the newly bufferized op. SmallVector newOperands = newInputBuffers; newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); + + // Set insertion point now that potential alloc/dealloc are introduced. + b.setInsertionPoint(op); op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands); // Replace the results of the old op with the new output buffers. @@ -1774,11 +1777,10 @@ BufferizationAliasInfo &aliasInfo) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); - Location loc = forOp.getLoc(); // If inPlace, just forward the buffer. // Otherwise alloc and copy. - b.setInsertionPoint(forOp); + Location loc = forOp.getLoc(); for (OpResult opResult : forOp->getResults()) { if (!opResult.getType().isa()) continue; @@ -1801,8 +1803,13 @@ // read". // TODO: "matching bbArg does not bufferize to a read" is a more general // check. - if (!isInitTensorOp(operand)) + if (!isInitTensorOp(operand)) { + OpBuilder::InsertionGuard g(b); + // Set insertion point now that potential alloc/dealloc are introduced. + // Copy is inserted just before the forOp. + b.setInsertionPoint(forOp); b.create(forOp.getLoc(), operandBuffer, resultBuffer); + } } BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand); aliasInfo.createAliasInfoEntry(resultBuffer); @@ -1860,6 +1867,7 @@ BufferizationAliasInfo &aliasInfo) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); + // Cannot insert after returnOp. b.setInsertionPoint(returnOp); assert(isa(returnOp->getParentOp()) && @@ -1884,7 +1892,6 @@ BufferizationAliasInfo &aliasInfo) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(tiledLoopOp); // Allocate output buffers if needed, forward output tensor args to the // terminator. @@ -1937,7 +1944,10 @@ // TODO: "matching bbArg does not bufferize to a read" is a more general // check. if (!isInitTensorOp(oldOutputTensor)) { - b.setInsertionPointAfter(alloc.getDefiningOp()); + OpBuilder::InsertionGuard g(b); + // Set insertion point now that potential alloc/dealloc are introduced. + // Copy is inserted just before the tiledLoopOp. + b.setInsertionPoint(tiledLoopOp); b.create(loc, outputBuffer, alloc); } outputBuffer = alloc; @@ -2023,11 +2033,10 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo) { - LDBG("bufferize: " << *extractSliceOp << '\n'); - // Take a guard before anything else. OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(extractSliceOp); + + LDBG("bufferize: " << *extractSliceOp << '\n'); Location loc = extractSliceOp.getLoc(); // Bail if source was not bufferized. @@ -2045,6 +2054,9 @@ alloc = createNewAllocDeallocPairForShapedValue( b, loc, extractSliceOp.result(), aliasInfo); + // Set insertion point now that potential alloc/dealloc are introduced. + b.setInsertionPoint(extractSliceOp); + // Bufferize to subview. auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( @@ -2071,13 +2083,13 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo) { - LDBG("bufferize: " << *insertSliceOp << '\n'); - // Take a guard before anything else. OpBuilder::InsertionGuard g(b); b.setInsertionPoint(insertSliceOp); - Location loc = insertSliceOp.getLoc(); + LDBG("bufferize: " << *insertSliceOp << '\n'); + + Location loc = insertSliceOp.getLoc(); Value dstMemref = lookup(bvm, insertSliceOp.dest()); if (!dstMemref) return failure(); @@ -2092,6 +2104,8 @@ // buffer. Value newDstMemref = createNewAllocDeallocPairForShapedValue( b, loc, insertSliceOp.dest(), aliasInfo); + // Set insertion point now that potential alloc/dealloc are introduced. + b.setInsertionPoint(insertSliceOp); b.create(insertSliceOp.getLoc(), dstMemref, newDstMemref); dstMemref = newDstMemref; } @@ -2137,7 +2151,6 @@ // Take a guard before anything else. OpBuilder::InsertionGuard g(b); b.setInsertionPoint(op); - Location loc = op.getLoc(); if (op.getShapedType().isa()) return failure(); @@ -2156,13 +2169,17 @@ // If transfer_write is not inPlace, allocate a new buffer. Value newInputBuffer; + Location loc = op.getLoc(); if (inPlace != InPlaceSpec::True) { // Alloc a copy for `writeOp.source()`, it will become the result buffer. newInputBuffer = createNewAllocDeallocPairForShapedValue( b, loc, writeOp.source(), aliasInfo); Value v = lookup(bvm, writeOp.source()); - if (!isInitTensorOp(writeOp.source())) + if (!isInitTensorOp(writeOp.source())) { + // Set insertion point now that potential alloc/dealloc are introduced. + b.setInsertionPoint(op); b.create(loc, v, newInputBuffer); + } } else { // InPlace write will result in memref.tensor_load(x) which must // canonicalize away with one of it uses. @@ -2188,6 +2205,7 @@ BufferizationAliasInfo &aliasInfo) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); + // Cannot create IR past a yieldOp. b.setInsertionPoint(yieldOp); scf::ForOp forOp = dyn_cast(yieldOp->getParentOp()); @@ -2225,7 +2243,9 @@ BufferizationAliasInfo &aliasInfo) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); + // Cannot create IR past a yieldOp. b.setInsertionPoint(yieldOp); + // No tensors -> success. if (!llvm::any_of(yieldOp.getOperandTypes(), isaTensor)) return success(); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -702,3 +702,29 @@ return %r : tensor<256x256xf32> } + +// ----- + +//===----------------------------------------------------------------------===// +// Insert point issue cases. +//===----------------------------------------------------------------------===// + +// Only test IR validity wrt dominance. +// CHECK-LABEL: func @ip +func @ip(%t: tensor<10x20xf32> {linalg.inplaceable = true}, + %x: index, %y: index, %v: vector<5x6xf32>) + -> tensor<10x20xf32> +{ + %c0 = constant 0 : index + %c256 = constant 256 : index + %c257 = constant 257 : index + %r = scf.for %arg0 = %c0 to %c257 step %c256 iter_args(%arg1 = %t) -> (tensor<10x20xf32>) { + %t1 = tensor.extract_slice %arg1[%x, 0] [5, %y] [1, 1] : tensor<10x20xf32> to tensor<5x?xf32> + %t11 = tensor.extract_slice %t1[0, 0] [5, %y] [1, 1] : tensor<5x?xf32> to tensor<5x?xf32> + %t2 = vector.transfer_write %v, %t11[%c0, %c0] : vector<5x6xf32>, tensor<5x?xf32> + %t3 = tensor.insert_slice %t2 into %arg1[%x, 0] [5, %y] [1, 1] : tensor<5x?xf32> into tensor<10x20xf32> + scf.yield %t3 : tensor<10x20xf32> + } + return %r : tensor<10x20xf32> +} +