diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2512,6 +2512,44 @@
         .getOperation();
   }
 
+  // Take a value and widen to have the same element type as `ty`.
+  //
+  // Supported conversions (the shape of `val` is preserved):
+  //   * integer -> wider integer: sign extension (arith.extsi),
+  //   * float   -> wider float:   arith.extf,
+  //   * integer -> float:         signed conversion (arith.sitofp).
+  // Any other combination (narrowing, float -> integer) is unsupported: it
+  // asserts in debug builds and returns a null Value otherwise.
+  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
+    const Type srcElementType = getElementTypeOrSelf(val.getType());
+    const Type dstElementType = getElementTypeOrSelf(ty);
+    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
+    if (srcElementType == dstElementType)
+      return val;
+
+    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
+    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
+    // Keep the shape of `val`, only swap its element type.
+    const Type dstType =
+        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
+
+    // Dispatch on both element kinds: each arith cast op accepts only one
+    // operand kind (e.g. arith.extf requires a float operand).
+    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType))
+      return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
+
+    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
+        srcWidth < dstWidth)
+      return rewriter.create<arith::ExtFOp>(loc, dstType, val);
+
+    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
+        srcWidth < dstWidth)
+      return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
+
+    assert(false && "unhandled promotion case");
+    return nullptr;
+  }
+
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
   Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
@@ -2519,6 +2557,8 @@
     vector::IteratorType red = vector::IteratorType::reduction;
     AffineExpr n, w, f, c;
     bindDims(ctx, n, w, f, c);
+    lhs = promote(rewriter, loc, lhs, res.getType());
+    rhs = promote(rewriter, loc, rhs, res.getType());
     return rewriter.create<vector::ContractionOp>(
         loc, lhs, rhs, res,
         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
@@ -2666,24 +2706,6 @@
         .getOperation();
   }
 
-  // Take a value of element type T and widen to the destination type.
-  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
-    if (val.getType() == ty)
-      return val;
-
-    const int64_t srcWidth =
-        getElementTypeOrSelf(val.getType()).getIntOrFloatBitWidth();
-    const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth();
-
-    if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth < destWidth)
-      return rewriter.create<arith::ExtSIOp>(loc, ty, val);
-
-    if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth < destWidth)
-      return rewriter.create<arith::ExtFOp>(loc, ty, val);
-
-    return nullptr;
-  }
-
   /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                      Value lhs, Value rhs, Value res) {
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -100,18 +100,22 @@
 // CHECK-SAME:   {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xi32> to vector<4x1x8xi32>
 
 /// w == 0, kw == 0
+// CHECK:      %[[EXT_LHS_0:.+]] = arith.extsi %[[V_INPUT_0]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+// CHECK:      %[[EXT_RHS_0:.+]] = arith.extsi %[[V_FILTER]] : vector<3x8xi8> to vector<3x8xi32>
 // CHECK:      %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME:   indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME:   %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
-// CHECK-SAME:   : vector<4x1x3xi8>, vector<3x8xi8> into vector<4x1x8xi32>
+// CHECK-SAME:   %[[EXT_LHS_0]], %[[EXT_RHS_0]], %[[V_OUTPUT_0]]
+// CHECK-SAME:   : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32>
 
 /// w == 1, kw == 0
+// CHECK:      %[[EXT_LHS_1:.+]] = arith.extsi %[[V_INPUT_1]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+// CHECK:      %[[EXT_RHS_1:.+]] = arith.extsi %[[V_FILTER]] : vector<3x8xi8> to vector<3x8xi32>
 // CHECK:      %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME:   indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME:   %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
-// CHECK-SAME:   : vector<4x1x3xi8>, vector<3x8xi8> into vector<4x1x8xi32>
+// CHECK-SAME:   %[[EXT_LHS_1]], %[[EXT_RHS_1]], %[[V_OUTPUT_1]]
+// CHECK-SAME:   : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32>
 
 /// w == 0, kw == 0
 // CHECK:      %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]