Linalg conv vectorization: bind shape dims via a helper, hoist slice
extraction out of the kw/w loops, rename dilatedConv to depthwiseConv.

Summary of this patch (two files):

* mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
  - Adds `bindShapeDims`, a variadic helper documented as binding a pack
    of integer references, in order, to the leading entries of
    `shapedType.getShape()` (index 0, 1, ...).
  - conv(): the sizes become int64_t. kw/c/f are read from the kernel
    (rhs) shape and n/w from the output (res) shape. Previously n and c
    were read from the input (lhs) shape.
  - depthwiseConv() (formerly dilatedConv()): kw/c are read from the
    kernel (rhs) shape and n/w from the output (res) shape. Previously
    n and c were read from the lhs shape.
  - In both functions the slice extraction is split into three loops:
    1. the lhs slices, with kw outer and w inner (the same order that
       linearIndex assumes);
    2. one rhs slice per kw;
    3. one res slice per w step.
    Previously the rhs and res extraction was interleaved with the lhs
    loop, and res was extracted only on the kw == 0 iteration.
  - Renames dilatedConv -> depthwiseConv and
    dilatedConv1dSliceAsFma -> depthwiseConv1dSliceAsFma. The
    {n, sw*w + dw*kw, c} x {kw, c} -> {n, w, c} layout now dispatches
    to depthwiseConv().

* mlir/test/Dialect/Linalg/vectorize-convolution.mlir
  - Reorders the CHECK lines to match the new emission order. The input
    extract_strided_slice ops come first, then the filter vector.extract
    ops, then (where a test checks them) the output
    extract_strided_slice ops.

Review notes:
  - NOTE(review): this copy is flattened onto a few physical lines, and
    alphabetic angle-bracketed template arguments appear to have been
    stripped, for example:
      `template void bindShapeDims(...)`
      `builder.create(loc, 0)`
      `SmallVector lhsVals`
      `FailureOr depthwiseConv()`
    As shown, the patch will neither apply nor compile. Regenerate it
    from the original commit.
  - NOTE(review): in the recursive `bindShapeDims` overload, the
    recursive call appears as `bindShapeDims(shapedType, vals...)`.
    It needs explicit `<N + 1, ...>` template arguments. Without them,
    overload resolution would select the entry overload and restart at
    dimension 0. This was presumably lost in extraction; confirm.
  - The three `bindShapeDims` templates are added without `static`,
    just before the `namespace {` that follows them. LLVM style prefers
    file-local functions to be `static`.
  - depthwiseConv1dSliceAsFma has an inconsistency in an unchanged
    context line: it builds the broadcast with the member `builder`,
    but the FMA with its `b` parameter.
  - resVals[w] and linearIndex(kw, w) index by the raw `w` value. That
    is in range only if wSizeStep == 1 or w only takes the value 0.
    TODO confirm that wSizeStep is always chosen that way; its
    definition is not part of this patch.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1415,6 +1415,22 @@ //===----------------------------------------------------------------------===// // Convolution vectorization patterns //===----------------------------------------------------------------------===// + +template +void bindShapeDims(ShapedType shapedType) {} + +template +void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) { + val = shapedType.getShape()[N]; + bindShapeDims(shapedType, vals...); +} + +/// Bind a pack of int& to the leading dimensions of shapedType.getShape(). +template +void bindShapeDims(ShapedType shapedType, IntTy &...vals) { + bindShapeDims<0>(shapedType, vals...); +} + namespace { /// Generate a vector implementation for either: /// ``` @@ -1482,11 +1498,11 @@ if (!valid) return failure(); - int nSize = lhsShapedType.getShape()[0]; - int wSize = resShapedType.getShape()[1]; - int cSize = lhsShapedType.getShape()[2]; - int kwSize = rhsShapedType.getShape()[0]; - int fSize = rhsShapedType.getShape()[2]; + int64_t nSize, wSize, cSize, kwSize, fSize; + // kernel{kw, c, f} + bindShapeDims(rhsShapedType, kwSize, cSize, fSize); + // out{n, w, f} + bindShapeDims(resShapedType, nSize, wSize); vector::TransferWriteOp write; Value zero = builder.create(loc, 0); @@ -1526,31 +1542,29 @@ //===------------------------------------------------------------------===// // Unroll along kw and read slices of lhs and rhs. SmallVector lhsVals, rhsVals, resVals; + // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]. for (int64_t kw = 0; kw < kwSize; ++kw) { - // Extract rhs slice of size {c, f} @ [kw]. 
- rhsVals.push_back(builder.create( - loc, rhs, /*offsets=*/ArrayRef{kw})); - for (int64_t w = 0; w < wSize; w += wSizeStep) { - // Extract lhs slice of size {n, wSizeStep, c} - // @ [0, sw * w + dw * kw, 0]. lhsVals.push_back(builder.create( loc, lhs, /*offsets=*/ArrayRef{0, w * strideW + kw * dilationW, 0}, /*sizes=*/ArrayRef{nSize, wSizeStep, cSize}, /*strides=*/ArrayRef{1, 1, 1})); - - // This does not depend on kw. - if (kw == 0) { - // Extract res slice: {n, wSizeStep, f} @ [0, w, 0]. - resVals.push_back(builder.create( - loc, res, - /*offsets=*/ArrayRef{0, w, 0}, - /*sizes=*/ArrayRef{nSize, wSizeStep, fSize}, - /*strides=*/ArrayRef{1, 1, 1})); - } } } + // Extract rhs slice of size {c, f} @ [kw]. + for (int64_t kw = 0; kw < kwSize; ++kw) { + rhsVals.push_back(builder.create( + loc, rhs, /*offsets=*/ArrayRef{kw})); + } + // Extract res slice: {n, wSizeStep, f} @ [0, w, 0]. + for (int64_t w = 0; w < wSize; w += wSizeStep) { + resVals.push_back(builder.create( + loc, res, + /*offsets=*/ArrayRef{0, w, 0}, + /*sizes=*/ArrayRef{nSize, wSizeStep, fSize}, + /*strides=*/ArrayRef{1, 1, 1})); + } auto linearIndex = [&](int64_t kw, int64_t w) { return kw * (wSize / wSizeStep) + w; @@ -1604,14 +1618,15 @@ /// kw is always unrolled. /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is /// > 1. - FailureOr dilatedConv() { + FailureOr depthwiseConv() { if (!valid) return failure(); - int nSize = lhsShapedType.getShape()[0]; - int wSize = resShapedType.getShape()[1]; - int cSize = lhsShapedType.getShape()[2]; - int kwSize = rhsShapedType.getShape()[0]; + int64_t nSize, wSize, cSize, kwSize; + // kernel{kw, c} + bindShapeDims(rhsShapedType, kwSize, cSize); + // out{n, w, c} + bindShapeDims(resShapedType, nSize, wSize); vector::TransferWriteOp write; Value zero = builder.create(loc, 0); @@ -1650,31 +1665,30 @@ //===------------------------------------------------------------------===// // Unroll along kw and read slices of lhs and rhs. 
SmallVector lhsVals, rhsVals, resVals; + // Extract lhs slice of size {n, wSizeStep, c} + // @ [0, sw * w + dw * kw, 0]. for (int64_t kw = 0; kw < kwSize; ++kw) { - // Extract rhs slice of size {c} @ [kw]. - rhsVals.push_back(builder.create( - loc, rhs, /*offsets=*/ArrayRef{kw})); - for (int64_t w = 0; w < wSize; w += wSizeStep) { - // Extract lhs slice of size {n, wSizeStep, c} - // @ [0, sw * w + dw * kw, 0]. lhsVals.push_back(builder.create( loc, lhs, /*offsets=*/ArrayRef{0, w * strideW + kw * dilationW, 0}, /*sizes=*/ArrayRef{nSize, wSizeStep, cSize}, /*strides=*/ArrayRef{1, 1, 1})); - - // This does not depend on kw. - if (kw == 0) { - // Extract res slice: {n, wSizeStep, c} @ [0, w, 0]. - resVals.push_back(builder.create( - loc, res, - /*offsets=*/ArrayRef{0, w, 0}, - /*sizes=*/ArrayRef{nSize, wSizeStep, cSize}, - /*strides=*/ArrayRef{1, 1, 1})); - } } } + // Extract rhs slice of size {c} @ [kw]. + for (int64_t kw = 0; kw < kwSize; ++kw) { + rhsVals.push_back(builder.create( + loc, rhs, /*offsets=*/ArrayRef{kw})); + } + // Extract res slice: {n, wSizeStep, c} @ [0, w, 0]. + for (int64_t w = 0; w < wSize; w += wSizeStep) { + resVals.push_back(builder.create( + loc, res, + /*offsets=*/ArrayRef{0, w, 0}, + /*sizes=*/ArrayRef{nSize, wSizeStep, cSize}, + /*strides=*/ArrayRef{1, 1, 1})); + } auto linearIndex = [&](int64_t kw, int64_t w) { return kw * (wSize / wSizeStep) + w; @@ -1683,7 +1697,7 @@ // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c} for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - resVals[w] = dilatedConv1dSliceAsFma( + resVals[w] = depthwiseConv1dSliceAsFma( builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); } } @@ -1708,8 +1722,8 @@ } /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma. 
- Value dilatedConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs, - Value rhs, Value res) { + Value depthwiseConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs, + Value rhs, Value res) { Value bcast = builder.create(loc, res.getType(), rhs); return b.create(loc, lhs, bcast, res); } @@ -1742,7 +1756,7 @@ if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, /*rhsIndex*/ {kw, c}, /*resIndex*/ {n, w, c}})) - return dilatedConv(); + return depthwiseConv(); return failure(); } diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir --- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir @@ -23,15 +23,15 @@ // CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] // CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] -// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32> -/// w == 0, kw == 0 // CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] // CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]] -// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32> -/// w == 1, kw == 0 // CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] // CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32> + +// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32> + +// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32> // CHECK: %[[V_OUTPUT_1:.+]] = 
vector.extract_strided_slice %[[V_OUTPUT_R]] // CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32> @@ -84,27 +84,23 @@ // CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] // CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] - -/// w == 0, kw == 0 -// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32> // CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] // CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]] -// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32> -/// w == 1, kw == 0 // CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] // CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> -// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]] -// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32> - -/// w == 0, kw == 1 -// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32> // CHECK: %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] // CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> -/// w == 1, kw == 0 // CHECK: %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] // CHECK-SAME: {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> +// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32> +// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32> + +// CHECK: 
%[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32> +// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]] +// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32> + /// w == 0, kw == 0 // CHECK: %[[CONTRACT_0:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], @@ -165,15 +161,14 @@ // CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] // CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] -/// w == 0, kw == 0 -// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32> // CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] // CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32> -/// w == 0, kw == 1 -// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32> // CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] // CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32> +// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32> +// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32> + /// w == 0, kw == 0 // CHECK: %[[CONTRACT_0:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], @@ -211,15 +206,14 @@ // CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]] // CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] -/// w == 0, kw == 0 -// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32> // CHECK: 
%[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] // CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32> -/// w == 0, kw == 1 -// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32> // CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] // CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32> +// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32> +// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32> + /// w == 0, kw = 0 // CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32> // CHECK: %[[FMA_0:.*]] = vector.fma %[[V_INPUT_0]], %[[B_FILTER_0]], %[[V_OUTPUT_R]] : vector<3x2x4xf32>