diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -700,15 +700,22 @@
 BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
   // Keep all types unchanged.
   addConversion([](Type type) { return type; });
-  // A type conversion that converts ranked-tensor type to memref type.
+  // Convert RankedTensorType to MemRefType.
   addConversion([](RankedTensorType type) {
     return (Type)MemRefType::get(type.getShape(), type.getElementType());
   });
+  // Convert UnrankedTensorType to UnrankedMemRefType (memory space 0).
+  addConversion([](UnrankedTensorType type) {
+    return (Type)UnrankedMemRefType::get(type.getElementType(), 0);
+  });
 }
 
 /// Checks if `type` has been converted from non-memref type to memref.
 bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) {
-  return type.isa<MemRefType>() && !before.isa<MemRefType>();
+  auto isMemRefType = [](Type t) -> bool {
+    return t.isa<MemRefType>() || t.isa<UnrankedMemRefType>();
+  };
+  return isMemRefType(type) && !isMemRefType(before);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
--- a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
@@ -64,6 +64,15 @@
 
 // -----
 
+// CHECK-LABEL: func @func_with_unranked_arg_and_result
+func @func_with_unranked_arg_and_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  return %arg0 : tensor<*xf32>
+}
+// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32>
+// CHECK-NEXT: return [[ARG]] : memref<*xf32>
+
+// -----
+
 // CHECK-LABEL: func @func_and_block_signature_conversion
 func @func_and_block_signature_conversion(%arg0 : tensor<2xf32>, %cond : i1, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32>{
     cond_br %cond, ^bb1, ^bb2
diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir
--- a/mlir/test/Transforms/buffer-placement-preparation.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -284,3 +284,9 @@
 // CHECK: %[[Y1:.*]] = call @callee(%[[X0]], %[[Y0]])
 // CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
 // CHECK: return
+
+// CHECK-LABEL: func @func_with_unranked_arg
+func @func_with_unranked_arg(%arg0: tensor<*xf32>) {
+  return
+}
+// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)