diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1289,9 +1289,16 @@ Value m; if (acc) { Value e = rewriter.create(loc, rhsType, acc, pos); - m = rewriter.create(loc, b, op.rhs(), e); + if (eltType.isa()) + m = rewriter.create( + loc, rewriter.create(loc, b, op.rhs()), e); + else + m = rewriter.create(loc, b, op.rhs(), e); } else { - m = rewriter.create(loc, b, op.rhs()); + if (eltType.isa()) + m = rewriter.create(loc, b, op.rhs()); + else + m = rewriter.create(loc, b, op.rhs()); } result = rewriter.create(loc, resType, m, result, pos); } diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -293,6 +293,50 @@ return %0: vector<2x3xf32> } +// CHECK-LABEL: func @outerproduct_noacc_int +// CHECK-SAME: %[[A:.*0]]: vector<2xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi32> +// CHECK: %[[C0:.*]] = constant dense<0> : vector<2x3xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32> +// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi32> +// CHECK: %[[T2:.*]] = muli %[[T1]], %[[B]] : vector<3xi32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32> +// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xi32> +// CHECK: %[[T6:.*]] = muli %[[T5]], %[[B]] : vector<3xi32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32> +// CHECK: return %[[T7]] : vector<2x3xi32> +func @outerproduct_noacc_int(%arg0: vector<2xi32>, + %arg1: vector<3xi32>) -> vector<2x3xi32> { + %0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32> + return %0: vector<2x3xi32> +} + +// CHECK-LABEL: func @outerproduct_acc_int +// CHECK-SAME: %[[A:.*0]]: vector<2xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, +// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> +// CHECK: %[[C0:.*]] = constant dense<0> : vector<2x3xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32> +// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi32> +// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32> +// CHECK: %[[T3:.*]] = muli %[[T1]], %[[B]] : vector<3xi32> +// CHECK: %[[T4:.*]] = addi %[[T3]], %[[T2]] : vector<3xi32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32> +// CHECK: %[[T7:.*]] = splat %[[T6]] : vector<3xi32> +// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32> +// CHECK: %[[T9:.*]] = muli %[[T7]], %[[B]] : vector<3xi32> +// CHECK: %[[T10:.*]] = addi %[[T9]], %[[T8]] : vector<3xi32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32> +// CHECK: return %[[T11]] : vector<2x3xi32> +func @outerproduct_acc_int(%arg0: vector<2xi32>, + %arg1: vector<3xi32>, + %arg2: vector<2x3xi32>) -> vector<2x3xi32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32> + return %0: vector<2x3xi32> +} + // CHECK-LABEL: func @transpose23 // CHECK-SAME: %[[A:.*]]: vector<2x3xf32> // CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>