diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2596,9 +2596,15 @@
   /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
   /// ```
   /// kw is always unrolled.
+  ///
+  /// Set \p flattenChannelDim to true to flatten the channel dimension too.
+  /// This leads to better vectorisation when the number of channels is low
+  /// relative to native vector sizes (e.g. 1 vs 4).
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConv() {
+  FailureOr<Operation *> depthwiseConv(
+      bool flattenChannelDim = true /* TODO: Change the default to false in the final version */
+  ) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
@@ -2619,41 +2625,96 @@
     Type lhsEltType = lhsShapedType.getElementType();
     Type rhsEltType = rhsShapedType.getElementType();
     Type resEltType = resShapedType.getElementType();
-    VectorType lhsType = VectorType::get(
-        {nSize,
-         // iw = ow * sw + kw * dw - 1
-         //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
-         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
-         cSize},
-        lhsEltType);
     VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
-    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+    VectorType lhsType;
+    if (!flattenChannelDim) {
+      lhsType = VectorType::get(
+          {nSize,
+           // iw = ow * sw + kw * dw - 1
+           //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+           ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
+           cSize},
+          lhsEltType);
+    } else {
+      lhsType = VectorType::get(
+          {nSize,
+           // iw = (ow * sw + kw * dw - 1) * c
+           //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+           (((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1) *
+               cSize},
+          lhsEltType);
+    }
 
-    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
-    // 0].
-    Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+    VectorType resType;
+    if (!flattenChannelDim) {
+      resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+    } else {
+      resType = VectorType::get({nSize, wSize * cSize}, resEltType);
+    }
+
+    Value res, lhs, lhsFlat, resFlat;
     // Read rhs slice of size {kw, c} @ [0, 0].
     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                         ValueRange{zero, zero});
-    // Read res slice of size {n, w, c} @ [0, 0, 0].
-    Value res = rewriter.create<vector::TransferReadOp>(
-        loc, resType, resShaped, ValueRange{zero, zero, zero});
+
+    SmallVector<ReassociationIndices> reassociation;
+    if (!flattenChannelDim) {
+      // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
+      // 0].
+      lhs = rewriter.create<vector::TransferReadOp>(
+          loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+      // Read res slice of size {n, w, c} @ [0, 0, 0].
+      res = rewriter.create<vector::TransferReadOp>(
+          loc, resType, resShaped, ValueRange{zero, zero, zero});
+    } else {
+      reassociation = {{0}, {1, 2}};
+
+      // Flatten w and c dimensions
+      lhsFlat = rewriter.create<tensor::CollapseShapeOp>(
+          loc, RankedTensorType::get(lhsType.getShape(), lhsEltType), lhsShaped,
+          reassociation);
+      resFlat = rewriter.create<tensor::CollapseShapeOp>(
+          loc, RankedTensorType::get(resType.getShape(), resEltType), resShaped,
+          reassociation);
+
+      // Read lhs slice of size {n, (w * strideW + kw * dilationW) * c} @ [0,
+      // 0].
+      lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsFlat,
+                                                    ValueRange{zero, zero});
+      // Read res slice of size {n, w * c} @ [0, 0].
+      res = rewriter.create<vector::TransferReadOp>(loc, resType, resFlat,
+                                                    ValueRange{zero, zero});
+    }
 
     //===------------------------------------------------------------------===//
     // Begin vector-only rewrite part
     //===------------------------------------------------------------------===//
     // Unroll along kw and read slices of lhs and rhs.
     SmallVector<Value> lhsVals, rhsVals, resVals;
-    // Extract lhs slice of size {n, wSizeStep, c}
-    //   @ [0, sw * w + dw * kw, 0].
-    for (int64_t kw = 0; kw < kwSize; ++kw) {
-      for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, lhs,
-            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
-            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+    if (!flattenChannelDim) {
+      // Extract lhs slice of size {n, wSizeStep, c}
+      //   @ [0, sw * w + dw * kw, 0].
+      for (int64_t kw = 0; kw < kwSize; ++kw) {
+        for (int64_t w = 0; w < wSize; w += wSizeStep) {
+          lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+              loc, lhs,
+              /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
+              /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+              /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+        }
+      }
+    } else {
+      // Extract lhs slice of size {n, wSizeStep * c}
+      //   @ [0, (sw * w + dw * kw) * cSize].
+      for (int64_t kw = 0; kw < kwSize; ++kw) {
+        for (int64_t w = 0; w < wSize; w += wSizeStep) {
+          lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+              loc, lhs,
+              /*offsets=*/
+              ArrayRef<int64_t>{0, (w * strideW + kw * dilationW) * cSize},
+              /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep * cSize},
+              /*strides=*/ArrayRef<int64_t>{1, 1}));
+        }
       }
     }
     // Extract rhs slice of size {c} @ [kw].
@@ -2661,25 +2722,50 @@
       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
     }
-    // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
-    for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+
+    // Extract res slice
+    if (!flattenChannelDim) {
+      // Regular case: {n, wSizeStep, c} @ [0, w, 0].
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, res,
+            /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+      }
+    } else {
+      // Flattened case: {n, wSizeStep * c} @ [0, w].
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, res,
+            /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep * cSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1}));
+      }
     }
 
     auto linearIndex = [&](int64_t kw, int64_t w) {
       return kw * (wSize / wSizeStep) + w;
     };
 
-    // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
-    for (int64_t kw = 0; kw < kwSize; ++kw) {
-      for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
-                                                  lhsVals[linearIndex(kw, w)],
-                                                  rhsVals[kw], resVals[w]);
+    // Compute contraction
+    //   1. Regular:   O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
+    //   2. Flattened: O{n, w * c} += I{n, (sw * w + dw * kw) * c} * F{c}
+    if (!flattenChannelDim) {
+      for (int64_t kw = 0; kw < kwSize; ++kw) {
+        for (int64_t w = 0; w < wSize; w += wSizeStep) {
+          resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
+                                                    lhsVals[linearIndex(kw, w)],
+                                                    rhsVals[kw], resVals[w]);
+        }
+      }
+    } else {
+      for (int64_t kw = 0; kw < kwSize; ++kw) {
+        for (int64_t w = 0; w < wSize; w += wSizeStep) {
+          resVals[w] = depthwiseConv1dFlatSliceAsMulAcc(
+              rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
+              resVals[w]);
+        }
       }
     }
 
@@ -2693,18 +2779,40 @@
       return rewriter.notifyMatchFailure(op, "failed to create FMA");
     }
 
-    // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
-    // This does not depend on kw.
-    for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, resVals[w], res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
+    // Write back res slice. This does not depend on kw.
+    if (!flattenChannelDim) {
+      // Regular case: {n, wSizeStep, c} @ [0, w, 0].
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        res = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, resVals[w], res,
+            /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+            /*strides=*/ArrayRef<int64_t>{1, 1, 1});
+      }
+    } else {
+      // Flattened case: {n, wSizeStep * c} @ [0, w].
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        res = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, resVals[w], res,
+            /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1});
+      }
     }
     //===------------------------------------------------------------------===//
     // End vector-only rewrite part
     //===------------------------------------------------------------------===//
 
+    if (flattenChannelDim) {
+      // Write back res slice of size {n, w * c} @ [0, 0].
+      vector::TransferWriteOp resWrite =
+          rewriter.create<vector::TransferWriteOp>(loc, res, resFlat,
+                                                   ValueRange{zero, zero});
+
+      // Re-expand shape
+      return rewriter
+          .create<tensor::ExpandShapeOp>(loc, resShapedType,
+                                         resWrite.getResult(), reassociation)
+          .getOperation();
+    }
     // Write back res slice of size {n, w, c} @ [0, 0, 0].
     return rewriter
         .create<vector::TransferWriteOp>(loc, res, resShaped,
@@ -2735,6 +2843,39 @@
     return rewriter.create<arith::AddIOp>(loc, mul, res);
   }
+  /// Lower lhs{n, w * c} * rhs{c} -> res{n, w * c} to MulAcc
+  Value depthwiseConv1dFlatSliceAsMulAcc(RewriterBase &rewriter, Location loc,
+                                         Value lhs, Value rhs, Value res) {
+    auto rhsTy = rhs.getType().cast<ShapedType>();
+    auto resTy = res.getType().cast<ShapedType>();
+
+    lhs = promote(rewriter, loc, lhs, resTy);
+
+    auto rhsSize = rhs.getType().cast<ShapedType>().getShape()[0];
+    auto resSize = res.getType().cast<ShapedType>().getShape()[1];
+
+    SmallVector<int64_t> indices;
+    for (int i = 0; i < resSize / rhsSize; ++i) {
+      for (int j = 0; j < rhsSize; ++j)
+        indices.push_back(j);
+    }
+
+    rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
+
+    rhs = rewriter.create<vector::BroadcastOp>(
+        loc, resTy.clone(rhsTy.getElementType()), rhs);
+    rhs = promote(rewriter, loc, rhs, resTy);
+
+    if (!lhs || !rhs)
+      return nullptr;
+
+    if (resTy.getElementType().isa<FloatType>())
+      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+
+    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
+    return rewriter.create<arith::AddIOp>(loc, mul, res);
+  }
+
 
   /// Entry point for non-channeled convolution:
   /// {{w + kw}, {kw}, {w}}
   FailureOr<Operation *> generateNonChanneledConv() {
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -907,3 +907,35 @@
 // CHECK: %[[V7:.+]] = arith.addf %[[V5]], %[[V6]] : vector<4x3x2xf32>
 // CHECK: %[[V8:.+]] = vector.transpose %[[V7]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
 // CHECK: vector.transfer_write %[[V8:.+]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
+
+// -----
+
+func.func @flatten(%input: tensor<1x8x3xi8>, %filter: tensor<1x3xi8>, %output: tensor<1x8x3xi8>) -> (tensor<1x8x3xi8>) {
+  %res = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+     strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
+    outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
+  return %res : tensor<1x8x3xi8>
+}
+
+// CHECK-LABEL:   func.func @flatten(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x8x3xi8>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<1x3xi8>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x3xi8>, vector<1x3xi8>
+// CHECK:           %[[VAL_6:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
+// CHECK:           %[[VAL_7:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x24xi8>, vector<1x24xi8>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x24xi8>, vector<1x24xi8>
+// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_5]][0] : vector<1x3xi8>
+// CHECK:           %[[VAL_11:.*]] = vector.shuffle %[[VAL_10]], %[[VAL_10]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
+// CHECK:           %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<24xi8> to vector<1x24xi8>
+// CHECK:           %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_12]] : vector<1x24xi8>
+// CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : vector<1x24xi8>
+// CHECK:           %[[VAL_15:.*]] = vector.transfer_write %[[VAL_14]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x24xi8>, tensor<1x24xi8>
+// CHECK:           %[[VAL_16:.*]] = tensor.expand_shape %[[VAL_15]] {{\[\[}}0], [1, 2]] : tensor<1x24xi8> into tensor<1x8x3xi8>
+// CHECK:           return %[[VAL_16]] : tensor<1x8x3xi8>
+// CHECK:         }
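
Note (illustration only, not part of the patch): in the flattened path the {kw, c} filter slice is replicated across the flattened w * c lanes with a vector.shuffle before the multiply-accumulate, which is what the shuffle mask in the @flatten CHECK lines above encodes. The standalone C++ sketch below mirrors the index-building loop of depthwiseConv1dFlatSliceAsMulAcc for the shapes used in that test (c = 3, w * c = 24); the constants are taken from the test, everything else is plain standard-library code.

// Standalone sketch: how the replicated shuffle mask is formed.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t rhsSize = 3;  // c: filter elements per kw slice
  const int64_t resSize = 24; // w * c: flattened result lanes

  // Repeat [0, 1, ..., rhsSize - 1] resSize / rhsSize times, mirroring the
  // indices loop in depthwiseConv1dFlatSliceAsMulAcc.
  std::vector<int64_t> mask;
  for (int64_t i = 0; i < resSize / rhsSize; ++i)
    for (int64_t j = 0; j < rhsSize; ++j)
      mask.push_back(j);

  // Prints: 0 1 2 0 1 2 ... (eight repetitions), i.e. the vector.shuffle
  // mask seen in the @flatten CHECK lines.
  for (int64_t m : mask)
    std::cout << m << ' ';
  std::cout << '\n';
  return 0;
}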