diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1477,7 +1477,7 @@ {nSize, - // iw = ow * sw + kw * dw - 1 + // iw = (ow - 1) * sw + (kw - 1) * dw + 1 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) - // Perform the proper inclusive -> exclusive -> inclusive + // Perform the proper inclusive -> exclusive -> inclusive. ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1, cSize}, lhsEltType); @@ -1557,9 +1557,8 @@ } // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f} - vector::ContractionOp conv1dSliceAsContraction(OpBuilder &b, Location loc, - Value lhs, Value rhs, - Value res) { + Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs, + Value rhs, Value res) { StringRef par = Par().strRef, red = Red().strRef; AffineExpr n, w, f, c; bindDims(ctx, n, w, f, c); @@ -1597,7 +1596,10 @@ Type rhsEltType = rhsShapedType.getElementType(); Type resEltType = resShapedType.getElementType(); VectorType lhsType = VectorType::get( - {nSize, (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1, + {nSize, + // iw = (ow - 1) * sw + (kw - 1) * dw + 1 + // (i.e.
16 convolved with 3 (@stride 1 dilation 1) -> 14) + ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1, cSize}, lhsEltType); VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType); @@ -1651,7 +1653,7 @@ // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c} for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - resVals[w] = dilatedConv1dSliceAsContraction( + resVals[w] = dilatedConv1dSliceAsFma( builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); } } @@ -1675,17 +1677,11 @@ .getOperation(); } - // Create a contraction: lhs{n, w, c} * rhs{c} -> res{n, w, c} - vector::ContractionOp dilatedConv1dSliceAsContraction(OpBuilder &b, - Location loc, Value lhs, - Value rhs, Value res) { - StringRef par = Par().strRef, red = Red().strRef; - AffineExpr n, w, c; - bindDims(ctx, n, w, c); - return builder.create<vector::ContractionOp>( - loc, lhs, rhs, res, - /*indexingMaps=*/MapList{{n, w, c}, {c}, {n, w, c}}, - /*iteratorTypes=*/ArrayRef<StringRef>{par, par, red}); + /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to broadcast + fma.
+ Value dilatedConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs, + Value rhs, Value res) { + Value bcast = b.create<vector::BroadcastOp>(loc, res.getType(), rhs); + return b.create<vector::FMAOp>(loc, lhs, bcast, res); } /// Entry point that transposes into the common form: diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir --- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir @@ -200,9 +200,6 @@ return } -// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2)> - // CHECK: func @depthwise_conv1d_nwc_wc_3x5x4_memref // CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>) @@ -217,24 +214,19 @@ /// w == 0, kw == 0 // CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32> // CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] -// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32> +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32> /// w == 0, kw == 1 // CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32> // CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] -// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32> +// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32> -/// w == 0, kw == 0 -// CHECK: %[[CONTRACT_0:.+]] = vector.contract { -// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} -// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]] -// 
CHECK-SAME: : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32> -/// w == 0, kw == 1 -// CHECK: %[[CONTRACT_1:.+]] = vector.contract { -// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} -// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]] -// CHECK-SAME: : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32> +/// w == 0, kw == 0 +// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32> +// CHECK: %[[FMA_0:.*]] = vector.fma %[[V_INPUT_0]], %[[B_FILTER_0]], %[[V_OUTPUT_R]] : vector<3x2x4xf32> + +/// w == 0, kw == 1 +// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xf32> to vector<3x2x4xf32> +// CHECK: %[[FMA_1:.*]] = vector.fma %[[V_INPUT_1]], %[[B_FILTER_1]], %[[FMA_0]] : vector<3x2x4xf32> // Write the result back in one shot. -// CHECK: vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK: vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]