diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -62,6 +62,112 @@ return res; } +/// Helper function to extract the input slices after filter is unrolled along +/// kw. +static SmallVector +extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, + int64_t nSize, int64_t wSize, int64_t cSize, + int64_t kwSize, int strideW, int dilationW, + int64_t wSizeStep, bool isSingleChanneled) { + SmallVector result; + if (isSingleChanneled) { + // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled + // convolution. + SmallVector sizes{wSizeStep}; + SmallVector strides{1}; + for (int64_t kw = 0; kw < kwSize; ++kw) { + for (int64_t w = 0; w < wSize; w += wSizeStep) { + result.push_back(rewriter.create( + loc, input, /*offsets=*/ArrayRef{w + kw}, sizes, strides)); + } + } + } else { + // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0] + // for channeled convolution. + SmallVector sizes{nSize, wSizeStep, cSize}; + SmallVector strides{1, 1, 1}; + for (int64_t kw = 0; kw < kwSize; ++kw) { + for (int64_t w = 0; w < wSize; w += wSizeStep) { + result.push_back(rewriter.create( + loc, input, + /*offsets=*/ArrayRef{0, w * strideW + kw * dilationW, 0}, + sizes, strides)); + } + } + } + return result; +} + +/// Helper function to extract the filter slices after filter is unrolled along +/// kw. +static SmallVector extractConvFilterSlices(RewriterBase &rewriter, + Location loc, Value filter, + int64_t kwSize) { + SmallVector result; + // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for + // non-channeled convolution] @ [kw]. 
+ for (int64_t kw = 0; kw < kwSize; ++kw) { + result.push_back(rewriter.create( + loc, filter, /*offsets=*/ArrayRef{kw})); + } + return result; +} + +/// Helper function to extract the result slices after filter is unrolled along +/// kw. +static SmallVector +extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, + int64_t nSize, int64_t wSize, int64_t fSize, + int64_t wSizeStep, bool isSingleChanneled) { + SmallVector result; + if (isSingleChanneled) { + // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution. + SmallVector sizes{wSizeStep}; + SmallVector strides{1}; + for (int64_t w = 0; w < wSize; w += wSizeStep) { + result.push_back(rewriter.create( + loc, res, /*offsets=*/ArrayRef{w}, sizes, strides)); + } + } else { + // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled + // convolution. + SmallVector sizes{nSize, wSizeStep, fSize}; + SmallVector strides{1, 1, 1}; + for (int64_t w = 0; w < wSize; w += wSizeStep) { + result.push_back(rewriter.create( + loc, res, /*offsets=*/ArrayRef{0, w, 0}, sizes, strides)); + } + } + return result; +} + +/// Helper function to insert the computed result slices. +static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, + Value res, int64_t wSize, int64_t wSizeStep, + SmallVectorImpl &resVals, + bool isSingleChanneled) { + + if (isSingleChanneled) { + // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution. + // This does not depend on kw. + SmallVector strides{1}; + for (int64_t w = 0; w < wSize; w += wSizeStep) { + res = rewriter.create( + loc, resVals[w], res, /*offsets=*/ArrayRef{w}, strides); + } + } else { + // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled + // convolution. This does not depend on kw. 
+ SmallVector strides{1, 1, 1}; + for (int64_t w = 0; w < wSize; w += wSizeStep) { + res = rewriter.create( + loc, resVals[w], res, /*offsets=*/ArrayRef{0, w, 0}, + strides); + } + } + return res; +} + /// Contains the vectorization state and related methods used across the /// vectorization process of a given operation. struct VectorizationState { @@ -334,6 +440,7 @@ /// Helper enum to represent conv1d input traversal order. enum class Conv1DOpOrder { + W, // Corresponds to non-channeled 1D convolution operation. Ncw, // Corresponds to operation that traverses the input in (n, c, w) order. Nwc // Corresponds to operation that traverses the input in (n, w, c) order. }; @@ -2055,6 +2162,15 @@ /// Generate a vector implementation for either: /// ``` +/// Op def: ( w, kw ) +/// Iters: ({Par(), Red()}) +/// Layout: {{w + kw}, {kw}, {w}} +/// ``` +/// kw is unrolled. +/// +/// or +/// +/// ``` /// Op def: ( n, w, c, kw, f ) /// Iters: ({Par(), Par(), Par(), Red(), Red()}) /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} @@ -2095,8 +2211,10 @@ resShapedType = resShaped.getType().dyn_cast(); if (!lhsShapedType || !rhsShapedType || !resShapedType) return; - // LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC. - if (lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) + // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR + // (non-channeled convolution -> LHS and RES both have a single dimension). 
+ if (!((lhsShapedType.getRank() == 3 && resShapedType.getRank() == 3) || + (lhsShapedType.getRank() == 1 && resShapedType.getRank() == 1))) return; Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0)); @@ -2115,7 +2233,7 @@ auto rhsRank = rhsShapedType.getRank(); switch (oper) { case Conv: - if (rhsRank != 2 && rhsRank!= 3) + if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3) return; break; case Pool: @@ -2129,6 +2247,15 @@ /// Generate a vector implementation for: /// ``` + /// Op def: ( w, kw ) + /// Iters: ({Par(), Red()}) + /// Layout: {{w + kw}, {kw}, {w}} + /// ``` + /// kw is always unrolled. + /// + /// or + /// + /// ``` /// Op def: ( n, w, c, kw, f ) /// Iters: ({Par(), Par(), Par(), Red(), Red()}) /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} @@ -2142,7 +2269,21 @@ int64_t nSize, wSize, cSize, kwSize, fSize; SmallVector lhsShape, rhsShape, resShape; + bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W); switch (conv1DOpOrder) { + case Conv1DOpOrder::W: + // Initialize unused dimensions + nSize = fSize = cSize = 0; + // out{W} + bindShapeDims(resShapedType, wSize); + // kernel{kw} + bindShapeDims(rhsShapedType, kwSize); + lhsShape = {// iw = ow + kw - 1 + // (i.e. 16 convolved with 3 -> 14) + (wSize + kwSize - 1)}; + rhsShape = {kwSize}; + resShape = {wSize}; + break; case Conv1DOpOrder::Nwc: // out{n, w, f} bindShapeDims(resShapedType, nSize, wSize, fSize); @@ -2220,24 +2361,27 @@ auto lhsType = VectorType::get(lhsShape, lhsEltType); auto rhsType = VectorType::get(rhsShape, rhsEltType); auto resType = VectorType::get(resShape, resEltType); - // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0, - // 0]. - Value lhs = rewriter.create( - loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}); - // Read rhs slice of size {kw, c, f} @ [0, 0, 0]. + // Zero padding with the corresponding dimensions for lhs, rhs and res. 
+ SmallVector lhsPadding(lhsShape.size(), zero); + SmallVector rhsPadding(rhsShape.size(), zero); + SmallVector resPadding(resShape.size(), zero); + + // Read the whole lhs, rhs and res in one shot (with zero padding). + Value lhs = rewriter.create(loc, lhsType, lhsShaped, + lhsPadding); // This is needed only for Conv. Value rhs = nullptr; if (oper == Conv) - rhs = rewriter.create( - loc, rhsType, rhsShaped, ValueRange{zero, zero, zero}); - // Read res slice of size {n, w, f} @ [0, 0, 0]. - Value res = rewriter.create( - loc, resType, resShaped, ValueRange{zero, zero, zero}); - - // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output: - // {n,w,f}. To reuse the base pattern vectorization case, we do pre - // transpose on input, weight, and output. + rhs = rewriter.create(loc, rhsType, rhsShaped, + rhsPadding); + Value res = rewriter.create(loc, resType, resShaped, + resPadding); + + // The base vectorization case for channeled convolution is input: {n,w,c}, + // weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern + // vectorization case, we do pre transpose on input, weight, and output. switch (conv1DOpOrder) { + case Conv1DOpOrder::W: case Conv1DOpOrder::Nwc: // Base case, so no transposes necessary. break; @@ -2264,45 +2408,35 @@ //===------------------------------------------------------------------===// // Unroll along kw and read slices of lhs and rhs. SmallVector lhsVals, rhsVals, resVals; - // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]. - for (int64_t kw = 0; kw < kwSize; ++kw) { - for (int64_t w = 0; w < wSize; w += wSizeStep) { - lhsVals.push_back(rewriter.create( - loc, lhs, - /*offsets=*/ArrayRef{0, w * strideW + kw * dilationW, 0}, - /*sizes=*/ArrayRef{nSize, wSizeStep, cSize}, - /*strides=*/ArrayRef{1, 1, 1})); - } - } - // Extract rhs slice of size {c, f} @ [kw]. 
+ lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize, + kwSize, strideW, dilationW, wSizeStep, + isSingleChanneled); // Do not do for pooling. if (oper == Conv) - for (int64_t kw = 0; kw < kwSize; ++kw) { - rhsVals.push_back(rewriter.create( - loc, rhs, /*offsets=*/ArrayRef{kw})); - } - // Extract res slice: {n, wSizeStep, f} @ [0, w, 0]. - for (int64_t w = 0; w < wSize; w += wSizeStep) { - resVals.push_back(rewriter.create( - loc, res, - /*offsets=*/ArrayRef{0, w, 0}, - /*sizes=*/ArrayRef{nSize, wSizeStep, fSize}, - /*strides=*/ArrayRef{1, 1, 1})); - } + rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize); + resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize, + wSizeStep, isSingleChanneled); auto linearIndex = [&](int64_t kw, int64_t w) { return kw * (wSize / wSizeStep) + w; }; // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or + // perform outerproduct for non-channeled convolution or // perform simple arith operation for pooling for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { switch (oper) { case Conv: - resVals[w] = conv1dSliceAsContraction(rewriter, loc, - lhsVals[linearIndex(kw, w)], - rhsVals[kw], resVals[w]); + if (isSingleChanneled) { + resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc, + lhsVals[linearIndex(kw, w)], + rhsVals[kw], resVals[w]); + } else { + resVals[w] = conv1dSliceAsContraction(rewriter, loc, + lhsVals[linearIndex(kw, w)], + rhsVals[kw], resVals[w]); + } break; case Pool: resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)], @@ -2312,22 +2446,17 @@ } } - // Write back res slice: {n, wSizeStep, f} @ [0, w, 0]. - // This does not depend on kw. 
- for (int64_t w = 0; w < wSize; w += wSizeStep) { - res = rewriter.create( - loc, resVals[w], res, - /*offsets=*/ArrayRef{0, w, 0}, - /*strides=*/ArrayRef{1, 1, 1}); - } + res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals, + isSingleChanneled); //===------------------------------------------------------------------===// // End vector-only rewrite part //===------------------------------------------------------------------===// - // The base vectorization case is output: {n,w,f} + // The base vectorization case for channeled convolution is output: {n,w,f} // To reuse the result from base pattern vectorization case, we post // transpose the base case result. switch (conv1DOpOrder) { + case Conv1DOpOrder::W: case Conv1DOpOrder::Nwc: // Base case, so no transposes necessary. break; @@ -2339,10 +2468,8 @@ } } - // Write back res slice of size {n, w, f} @ [0, 0, 0]. return rewriter - .create(loc, res, resShaped, - ValueRange{zero, zero, zero}) + .create(loc, res, resShaped, resPadding) .getOperation(); } @@ -2359,6 +2486,14 @@ /*iteratorTypes=*/ArrayRef{par, par, par, red}); } + // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel + // convolution. + Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc, + Value lhs, Value rhs, Value res) { + return rewriter.create( + loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD); + } + // Create a reduction: lhs{n, w, c} -> res{n, w, c} Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs, Value res) { @@ -2531,6 +2666,24 @@ return rewriter.create(loc, mul, res); } + /// Entry point for non-channeled convolution: + /// {{w + kw}, {kw}, {w}} + FailureOr generateNonChanneledConv() { + AffineExpr w, kw; + bindDims(ctx, w, kw); + if (!iters({Par(), Red()})) + return rewriter.notifyMatchFailure(op, + "failed to match conv::W 1-par 1-red"); + + // No transposition needed. 
+ if (layout({/*lhsIndex*/ {w + kw}, + /*rhsIndex*/ {kw}, + /*resIndex*/ {w}})) + return conv(Conv1DOpOrder::W); + + return rewriter.notifyMatchFailure(op, "not a conv::W layout"); + } + /// Entry point that transposes into the common form: /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} FailureOr generateNwcConv() { @@ -2691,7 +2844,10 @@ auto stride = strides ? *strides.getValues().begin() : 1; auto dilation = dilations ? *dilations.getValues().begin() : 1; Conv1DGenerator e(rewriter, op, stride, dilation); - auto res = e.generateNwcConv(); + auto res = e.generateNonChanneledConv(); + if (succeeded(res)) + return res; + res = e.generateNwcConv(); if (succeeded(res)) return res; res = e.generateNcwConv(); diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir --- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir @@ -461,6 +461,59 @@ // CHECK: vector.transfer_write %[[RES]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] +// ----- + +func.func @conv1d_8_tensor(%input: tensor<11xf32>, %filter: tensor<4xf32>, %output: tensor<8xf32>) -> tensor<8xf32> { + %0 = linalg.conv_1d ins(%input, %filter : tensor<11xf32>, tensor<4xf32>) + outs(%output : tensor<8xf32>) -> tensor<8xf32> + return %0 : tensor<8xf32> +} + +// CHECK: func @conv1d_8_tensor +// CHECK-SAME: (%[[INPUT:.+]]: tensor<11xf32>, %[[FILTER:.+]]: tensor<4xf32>, %[[OUTPUT:.+]]: tensor<8xf32>) + +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 + +/// Read the whole data in one shot. 
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]]], %[[F0]] +// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]]], %[[F0]] +// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]]], %[[F0]] + +// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32> +// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [1], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32> +// CHECK: %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [2], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32> +// CHECK: %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [3], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32> + +// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xf32> +// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xf32> +// CHECK: %[[V_FILTER_2:.+]] = vector.extract %[[V_FILTER_R]][2] : vector<4xf32> +// CHECK: %[[V_FILTER_3:.+]] = vector.extract %[[V_FILTER_R]][3] : vector<4xf32> + +/// w == 0, kw == 0 +// CHECK: %[[RES_0:.+]] = vector.outerproduct +// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]] {kind = #vector.kind} +// CHECK-SAME: : vector<8xf32>, f32 +/// w == 0, kw == 1 +// CHECK: %[[RES_1:.+]] = vector.outerproduct +// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_1]], %[[RES_0]] {kind = #vector.kind} +// CHECK-SAME: : vector<8xf32>, f32 +/// w == 0, kw == 2 +// CHECK: %[[RES_2:.+]] = vector.outerproduct +// CHECK-SAME: %[[V_INPUT_2]], %[[V_FILTER_2]], %[[RES_1]] {kind = #vector.kind} +// CHECK-SAME: : vector<8xf32>, f32 +/// w == 0, kw == 3 +// CHECK: %[[RES_3:.+]] = vector.outerproduct +// CHECK-SAME: %[[V_INPUT_3]], %[[V_FILTER_3]], %[[RES_2]] {kind = #vector.kind} +// 
CHECK-SAME: : vector<8xf32>, f32 + +// Write the result back in one shot. +// CHECK: vector.transfer_write %[[RES_3]], %[[OUTPUT]][%[[C0]]] + // ----- func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {