diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1615,7 +1615,27 @@
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(castOp);
 
-  Type sourceType = lookup(bvm, castOp.source()).getType();
+  // If castOp is not inPlace, allocate a new buffer.
+  auto inPlace = getInPlace(castOp->getResult(0));
+  Value newBuffer;
+  if (inPlace != InPlaceSpec::True) {
+    Location loc = castOp.getLoc();
+    // Alloc a copy for `castOp.source()`; it will become the result buffer.
+    newBuffer = createNewAllocDeallocPairForShapedValue(b, loc, castOp.source(),
+                                                        aliasInfo);
+    if (!isInitTensorOp(castOp.source())) {
+      // Set insertion point now that potential alloc/dealloc are introduced.
+      b.setInsertionPoint(castOp);
+      b.create<CopyOp>(loc, lookup(bvm, castOp.source()), newBuffer);
+    }
+  } else {
+    // An in-place write will result in memref.tensor_load(x), which must
+    // canonicalize away with one of its uses.
+    newBuffer = lookup(bvm, castOp.source());
+    assert(newBuffer && "missing buffer");
+  }
+
+  Type sourceType = newBuffer.getType();
   auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
   auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
   assert(rankedMemRefType || unrankedMemRefType);
@@ -1629,8 +1649,7 @@
           : ArrayRef<AffineMap>{};
   Type memRefType = getContiguousOrUnrankedMemRefType(
       castOp.getResult().getType(), affineMaps, memorySpace);
-  Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType,
-                                       lookup(bvm, castOp.source()));
+  Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType, newBuffer);
   aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
   map(bvm, castOp.getResult(), res);
   return success();
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -737,3 +737,21 @@
   }
   return %0 : tensor<128x192xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_cast_not_in_place(
+//  CHECK-SAME:   %[[A:.*]]: memref<?xf32{{.*}}>, %[[B:.*]]: memref<?xf32{{.*}}>
+//       CHECK:   %[[alloc:.*]] = memref.alloc
+//       CHECK:   linalg.copy(%[[A]], %[[alloc]])
+//       CHECK:   %[[cast:.*]] = memref.cast %[[alloc]]
+func @tensor_cast_not_in_place(
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %B : tensor<?xf32>, %idx: index)
+  -> (tensor<?xf32>)
+{
+  %r0 = tensor.cast %A : tensor<?xf32> to tensor<4xf32>
+  %r1 = tensor.insert_slice %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r1 : tensor<?xf32>
+}
+
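For reference, the new test should bufferize to IR roughly like the sketch below. This is hand-written for illustration, not actual pass output: the SSA names, the #map layout on the subview, the dealloc placement, and the rewritten signature/return (layouts simplified to the identity for readability) are all assumptions; only the alloc/copy/cast sequence is what the CHECK lines pin down.

  #map = affine_map<(d0)[s0] -> (d0 + s0)>

  func @tensor_cast_not_in_place(%A: memref<?xf32>, %B: memref<?xf32>, %idx: index) {
    %c0 = constant 0 : index
    %dim = memref.dim %A, %c0 : memref<?xf32>
    // The cast is marked not-inplace because %A is later written in place by
    // the insert_slice; copy %A into a fresh buffer first.
    %alloc = memref.alloc(%dim) : memref<?xf32>
    linalg.copy(%A, %alloc) : memref<?xf32>, memref<?xf32>
    // Cast the copy, leaving %A untouched until the insert_slice.
    %cast = memref.cast %alloc : memref<?xf32> to memref<4xf32>
    // insert_slice bufferizes in place as a copy into a subview of %A.
    %sv = memref.subview %A[%idx] [4] [1] : memref<?xf32> to memref<4xf32, #map>
    linalg.copy(%cast, %sv) : memref<4xf32>, memref<4xf32, #map>
    memref.dealloc %alloc : memref<?xf32>
    return
  }

Before this change, the memref.cast was always created directly on the buffer of castOp.source(), ignoring the inPlace analysis; the new alloc/copy path ensures a not-inplace cast gets its own buffer instead of aliasing a buffer that a later in-place write will clobber.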