diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -2244,23 +2244,22 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto extractSliceOp = cast(op); + LDBG("bufferize: " << *extractSliceOp << '\n'); // Take a guard before anything else. OpBuilder::InsertionGuard g(b); b.setInsertionPoint(extractSliceOp); - LDBG("bufferize: " << *extractSliceOp << '\n'); - Location loc = extractSliceOp.getLoc(); - // Bail if source was not bufferized. Value srcMemref = state.lookupBuffer(extractSliceOp.source()); auto srcMemrefType = srcMemref.getType().cast(); auto dstTensorType = extractSliceOp.result().getType().cast(); // If not inplaceable, alloc. + bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0)); Value alloc; - if (!state.aliasInfo.isInPlace(extractSliceOp->getResult(0))) + if (!inplace) alloc = createNewAllocDeallocPairForShapedValue( b, loc, extractSliceOp.result(), state); @@ -2278,7 +2277,7 @@ state.aliasInfo.insertNewBufferAlias(subView, srcMemref); /// If not inplaceable, copy. - if (alloc) { + if (!inplace) { // Do not copy if the copied data is never read. if (isValueRead(extractSliceOp.result())) state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView, @@ -2374,34 +2373,23 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { + // insert_slice ops arise from tiling and bufferizing them out-of-place is + // generally a deal breaker. When used with loops, this ends up cloning the + // whole tensor on every single iteration and is a symptom of a + // catastrophically bad scheduling decision. + // TODO: be very loud about it or even consider failing the pass. auto insertSliceOp = cast(op); + LDBG("bufferize: " << *insertSliceOp << '\n'); // Take a guard before anything else. OpBuilder::InsertionGuard g(b); b.setInsertionPoint(insertSliceOp); - - LDBG("bufferize: " << *insertSliceOp << '\n'); - Location loc = insertSliceOp.getLoc(); - // Since insert_slice arise from tiling and introducing loops, this - // case is generally a deal breaker. When used with loops, this ends up - // cloning the whole tensor on every single iteration and is a symptom - // of a catastrophically bad scheduling decision. - // TODO: be very loud about it or even consider failing the pass. - // Alloc a copy for `insertSliceOp.dest()`, it will become the result - // buffer. + + // When bufferizing out-of-place, `getResultBuffer` allocates. Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state); if (!dstMemref) return failure(); - auto dstMemrefType = dstMemref.getType().cast(); - - Value srcMemref = state.lookupBuffer(insertSliceOp.source()); - auto subviewMemRefType = - memref::SubViewOp::inferRankReducedResultType( - insertSliceOp.getSourceType().getRank(), dstMemrefType, - insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), - insertSliceOp.getMixedStrides()) - .cast(); // A copy of the source buffer is needed if either: // - The producer of `source` is not inplace. This is the case where a @@ -2409,23 +2397,32 @@ // - The result is not inplace. This is the case where the whole tensor is // cloned and the clone needs to be updated. // TODO: Is this necessary? - if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo, - insertSliceOp) || - !state.aliasInfo.isInPlace(insertSliceOp->getResult(0))) { + bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp( + state.aliasInfo, insertSliceOp) || + !state.aliasInfo.isInPlace(insertSliceOp->getResult(0)); + if (needCopy) { LDBG("insert_slice needs extra source copy: " << insertSliceOp.source() << " -> copy\n"); // Take a subview of the dst. + auto dstMemrefType = dstMemref.getType().cast(); + auto subviewMemRefType = + memref::SubViewOp::inferRankReducedResultType( + insertSliceOp.getSourceType().getRank(), dstMemrefType, + insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), + insertSliceOp.getMixedStrides()) + .cast(); Value subView = b.create( loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); // Insert new alias. state.aliasInfo.insertNewBufferAlias(subView, dstMemref); + // Copy tensor. + Value srcMemref = state.lookupBuffer(insertSliceOp.source()); state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, subView); } state.mapBuffer(insertSliceOp.result(), dstMemref); - return success(); } };