Index: mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1324,6 +1324,12 @@
   if (!elementType.isIntOrFloat())
     return failure();
 
+  Type dstElementType = op.getType();
+  if (auto vecType = dstElementType.dyn_cast<VectorType>())
+    dstElementType = vecType.getElementType();
+  if (elementType != dstElementType)
+    return failure();
+
   // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
   // Bail out if the contraction cannot be put in this form.
   MLIRContext *ctx = op.getContext();
@@ -1416,11 +1422,33 @@
     return builder.create<vector::TransposeOp>(loc, v, perm);
   }
 
+  /// Extends `v` (a scalar or vector) so that its element type becomes
+  /// `dstElementType`. Returns `v` unchanged when the element types already
+  /// match; otherwise emits arith.extf for a floating-point destination and
+  /// arith.extsi for any other destination.
+  Value promote(Value v, Type dstElementType) {
+    Type elementType = v.getType();
+    auto vecType = elementType.dyn_cast<VectorType>();
+    if (vecType)
+      elementType = vecType.getElementType();
+    if (elementType == dstElementType)
+      return v;
+    Type promotedType = dstElementType;
+    if (vecType)
+      promotedType = VectorType::get(vecType.getShape(), promotedType);
+    if (dstElementType.isa<FloatType>())
+      return builder.create<arith::ExtFOp>(loc, promotedType, v);
+    return builder.create<arith::ExtSIOp>(loc, promotedType, v);
+  }
+
   Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
     assert(reductionSize > 0);
+    Type resElementType = res.getType().cast<VectorType>().getElementType();
     for (int64_t k = 0; k < reductionSize; ++k) {
       Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
       Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
+      a = promote(a, resElementType);
+      b = promote(b, resElementType);
       res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
                                                    res, kind);
     }
Index: mlir/test/Dialect/Vector/vector-contract-transforms.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -891,6 +891,25 @@
   return %0 : vector<2x3xf32>
 }
 
+// OUTERPRODUCT-LABEL: func @matmul_0_mixed
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf16>
+//      OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf16>
+//      OUTERPRODUCT: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
+//      OUTERPRODUCT: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
+//      OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
+//      OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
 #matmat_accesses_1 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (n, k)>,