diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -51,6 +51,39 @@ return BufferRelation::Equivalent; } + FailureOr<BaseMemRefType> + getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const DenseMap<Value, BaseMemRefType> &fixedTypes) const { + auto castOp = cast<tensor::CastOp>(op); + auto maybeSrcBufferType = + bufferization::getBufferType(castOp.getSource(), options, fixedTypes); + if (failed(maybeSrcBufferType)) + return failure(); + Attribute memorySpace = maybeSrcBufferType->getMemorySpace(); + + // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref + // type in case the input is an unranked tensor type. + + // Case 1: Casting an unranked tensor + if (castOp.getSource().getType().isa<UnrankedTensorType>()) { + // When casting to a ranked tensor, we cannot infer any static offset or + // strides from the source. Assume fully dynamic. + return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace); + } + + // Case 2: Casting to an unranked tensor type + if (castOp.getType().isa<UnrankedTensorType>()) { + return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace); + } + + // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not + // change 
+ auto rankedResultType = castOp.getType().cast<RankedTensorType>(); + return MemRefType::get( + rankedResultType.getShape(), rankedResultType.getElementType(), + maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace); + } + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto castOp = cast<tensor::CastOp>(op); @@ -60,25 +93,19 @@ getBuffer(rewriter, castOp.getSource(), options); if (failed(resultBuffer)) return failure(); - auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>(); - TensorType resultTensorType = - castOp.getResult().getType().cast<TensorType>(); - MemRefLayoutAttrInterface layout; - if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>()) - if (resultTensorType.isa<RankedTensorType>()) - layout = rankedMemRefType.getLayout(); - - // Compute the new memref type. - Type resultMemRefType = getMemRefType(castOp.getResult(), options, layout, - sourceMemRefType.getMemorySpace()); + // Compute the new type. + auto resultMemRefType = + bufferization::getBufferType(castOp.getResult(), options); + if (failed(resultMemRefType)) + return failure(); // Replace the op with a memref.cast. assert(memref::CastOp::areCastCompatible(resultBuffer->getType(), - resultMemRefType) && + *resultMemRefType) && "CallOp::bufferize: cast incompatible"); - replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType, - *resultBuffer); + replaceOpWithNewBufferizedOp<memref::CastOp>( + rewriter, op, *resultMemRefType, *resultBuffer); return success(); } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -925,3 +925,26 @@ } return } + +// ----- + +// This is a regression test. Make sure that bufferization succeeds. 
+ +// CHECK-LABEL: func @regression_cast_in_loop( +func.func @regression_cast_in_loop() -> tensor<2xindex> { + %false = arith.constant false + %c0 = arith.constant 0 : index + %0 = bufferization.alloc_tensor() : tensor<2xindex> + // CHECK: scf.while (%{{.*}} = %{{.*}}) : (memref<2xindex>) -> memref<2xindex> + %1 = scf.while (%arg0 = %0) : (tensor<2xindex>) -> tensor<2xindex> { + scf.condition(%false) %arg0 : tensor<2xindex> + } do { + // CHECK: ^bb0(%{{.*}}: memref<2xindex>): + ^bb0(%arg0: tensor<2xindex>): + %cast = tensor.cast %0 : tensor<2xindex> to tensor<?xindex> + %inserted = tensor.insert %c0 into %cast[%c0] : tensor<?xindex> + %cast_0 = tensor.cast %inserted : tensor<?xindex> to tensor<2xindex> + scf.yield %cast_0 : tensor<2xindex> + } + return %1 : tensor<2xindex> +} diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -40,8 +40,8 @@ // CHECK-LABEL: func @tensor.cast_from_unranked( // CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> { // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32> -// CHECK: %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32> -// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32> +// CHECK: %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32, strided<[?], offset: ?>> +// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32, strided<[?], offset: ?>> // CHECK: return %[[RET]] : tensor<2xf32> func.func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> { %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32> diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -347,3 +347,26 @@ %1 = tensor.dim %t, 
%c0 : tensor<?xf32> return %0, %1 : tensor<?xf32>, index } + +// ----- + +// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0 + 5)> +// CHECK-LABEL: func.func @cast_retains_buffer_layout( +// CHECK-SAME: %[[t:.*]]: memref<?xf32, #[[$map]]>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> { +// CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, #[[$map]]> to memref<10xf32, #[[$map]]> +// CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, #[[$map]]> to memref<?xf32, strided<[1], offset: 7>> +// CHECK: return %[[slice]] +func.func @cast_retains_buffer_layout( + %t: tensor<?xf32> + {bufferization.buffer_layout = affine_map<(d0) -> (d0 + 5)>}, + %sz: index) + -> (tensor<10xf32>, tensor<?xf32>) +{ + %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32> + %slice = tensor.extract_slice %casted[2][%sz][1] : tensor<10xf32> to tensor<?xf32> + + // Note: The %casted return type is folded away because both buffers are + // equivalent. Therefore, we currently lose some static type information + // in the caller. + return %casted, %slice : tensor<10xf32>, tensor<?xf32> +}