// Patch summary: in mlir/lib/Transforms/Bufferize.cpp, this widens the
// parameter type of the source-materialization lambda from RankedTensorType
// to TensorType. It also widens the target-materialization lambda from
// MemRefType to BaseMemRefType. The unchanged context already converts
// UnrankedTensorType to UnrankedMemRefType.
// Presumably the goal is for the materializations to also accept unranked
// tensor/memref types. That assumes TensorType/BaseMemRefType are the shared
// bases of the ranked and unranked variants -- confirm against the MLIR type
// hierarchy.
// The test hunk adds @tensor_cast_from_unranked and @tensor_cast_to_unranked
// to mlir/test/Dialect/Standard/bufferize.mlir. Each checks that a
// tensor_cast between tensor<*xf32> and tensor<2xf32> bufferizes to a
// tensor_to_memref, then a memref_cast, then a tensor_load.
// NOTE(review): this copy of the patch is damaged and cannot be applied or
// compiled as-is:
//   - its line breaks were collapsed;
//   - C++ template arguments appear stripped (e.g. `isa()` and
//     `builder.create(loc, ...)` have no `<...>`).
// Regenerate it from the upstream commit before use. Text before the first
// `diff --git` header is skipped by `git apply` and `patch`, so these notes
// do not change the patch itself.
diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp --- a/mlir/lib/Transforms/Bufferize.cpp +++ b/mlir/lib/Transforms/Bufferize.cpp @@ -27,13 +27,13 @@ addConversion([](UnrankedTensorType type) -> Type { return UnrankedMemRefType::get(type.getElementType(), 0); }); - addSourceMaterialization([](OpBuilder &builder, RankedTensorType type, + addSourceMaterialization([](OpBuilder &builder, TensorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); assert(inputs[0].getType().isa()); return builder.create(loc, type, inputs[0]); }); - addTargetMaterialization([](OpBuilder &builder, MemRefType type, + addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); assert(inputs[0].getType().isa()); diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir --- a/mlir/test/Dialect/Standard/bufferize.mlir +++ b/mlir/test/Dialect/Standard/bufferize.mlir @@ -86,6 +86,28 @@ return %0 : tensor<2xindex> } +// CHECK-LABEL: func @tensor_cast_from_unranked( +// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> { +// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32> +// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32> +// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32> +// CHECK: return %[[RET]] : tensor<2xf32> +func @tensor_cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> { + %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<2xf32> + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: func @tensor_cast_to_unranked( +// CHECK-SAME: %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> { +// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32> +// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32> +// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32> +// 
CHECK: return %[[RET]] : tensor<*xf32> +func @tensor_cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> { + %0 = tensor_cast %arg0 : tensor<2xf32> to tensor<*xf32> + return %0 : tensor<*xf32> +} + // CHECK-LABEL: func @tensor_from_elements( // CHECK-SAME: %[[ELEM0:.*]]: index, // CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {