Linalg vectorization: support integer-to-float promotion of convolution operands.

Vectorization.cpp: the promotion code now checks the source element type as
well as the destination element type. An integer source with a floating-point
destination gets a dedicated cast; the float-to-float and integer-to-integer
cases still extend only when the source is narrower than the destination.
Anything else still reaches the "unhandled promotion case" assertion.

vectorize-convolution.mlir: adds a linalg.conv_1d_nwc_wcf test with i8 input
and filter and an f32 output, checking that both operands go through
arith.sitofp before being fed to vector.contract.

NOTE(review): this patch was recovered from a copy in which all line breaks,
indentation and angle-bracketed arguments (e.g. isa<...>, create<...>,
#vector.kind<...>) had been lost. They were restored so that each hunk matches
its declared line counts; arith::SIToFPOp follows from the test's CHECK lines,
the remaining restored names and all whitespace should be confirmed against
the upstream file before applying.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2623,10 +2623,16 @@
     const Type dstType =
         cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
 
-    if (isa<FloatType>(dstElementType) && srcWidth < dstWidth)
+    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
+      return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
+    }
+
+    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
+        srcWidth < dstWidth)
       return rewriter.create<arith::ExtFOp>(loc, dstType, val);
 
-    if (isa<IntegerType>(dstElementType) && srcWidth < dstWidth)
+    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
+        srcWidth < dstWidth)
       return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
 
     assert(false && "unhandled promotion case");
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -631,6 +631,31 @@
 
 // -----
 
+func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter: memref<1x3x2xi8>, %output: memref<1x2x2xf32>) {
+  linalg.conv_1d_nwc_wcf
+    {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : memref<1x2x3xi8>, memref<1x3x2xi8>)
+    outs(%output : memref<1x2x2xf32>)
+  return
+}
+
+
+// CHECK-LABEL: func @conv_1d_nwc_wcf_mixed_int_fp_memref
+// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<1x2x3xi8>, %[[FILTER:[0-9a-z]+]]: memref<1x3x2xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<1x2x2xf32>)
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i8
+// CHECK: %[[READ0:.+]] = vector.transfer_read %arg0[%[[I0]], %[[I0]], %[[I0]]], %[[C0]]
+// CHECK: %[[READ1:.+]] = vector.transfer_read %arg1[%[[I0]], %[[I0]], %[[I0]]], %[[C0]]
+// CHECK: %[[READ2:.+]] = vector.transfer_read %arg2[%[[I0]], %[[I0]], %[[I0]]], %[[CST]]
+// CHECK: %[[EXT:.+]] = vector.extract %[[READ1]][0] : vector<1x3x2xi8>
+// CHECK: %[[CAST0:.+]] = arith.sitofp %[[READ0]] : vector<1x2x3xi8> to vector<1x2x3xf32>
+// CHECK: %[[CAST1:.+]] = arith.sitofp %[[EXT]] : vector<3x2xi8> to vector<3x2xf32>
+// CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[CAST0]], %[[CAST1]], %[[READ2]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %arg2[%[[I0]], %[[I0]], %[[I0]]]
+
+// -----
+
 func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
   linalg.pooling_nwc_sum
     {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}