Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -31,8 +31,8 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
-#include
 #include
+#include
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -62,6 +62,115 @@
   return res;
 }
 
+/// Helper function to extract the input slices after filter is unrolled along
+/// kw.
+static SmallVector<Value>
+extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
+                       int64_t nSize, int64_t wSize, int64_t cSize,
+                       int64_t kwSize, int strideW, int dilationW,
+                       int64_t wSizeStep, bool isSingleChanneled) {
+  SmallVector<Value> result;
+  if (isSingleChanneled) {
+    // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
+    // convolution.
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, input,
+            /*offsets=*/ArrayRef<int64_t>{w + kw},
+            /*sizes=*/ArrayRef<int64_t>{wSizeStep},
+            /*strides=*/ArrayRef<int64_t>{1}));
+      }
+    }
+  } else {
+    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
+    // for channeled convolution.
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, input,
+            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
+            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+      }
+    }
+  }
+  return result;
+}
+
+/// Helper function to extract the filter slices after filter is unrolled along
+/// kw.
+static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
+                                                  Location loc, Value filter,
+                                                  int64_t kwSize) {
+  SmallVector<Value> result;
+  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
+  // non-channeled convolution] @ [kw].
+  for (int64_t kw = 0; kw < kwSize; ++kw) {
+    result.push_back(rewriter.create<vector::ExtractOp>(
+        loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
+  }
+  return result;
+}
+
+/// Helper function to extract the result slices after filter is unrolled along
+/// kw.
+static SmallVector<Value>
+extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
+                        int64_t nSize, int64_t wSize, int64_t fSize,
+                        int64_t wSizeStep, bool isSingleChanneled) {
+  SmallVector<Value> result;
+  if (isSingleChanneled) {
+    // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, res,
+          /*offsets=*/ArrayRef<int64_t>{w},
+          /*sizes=*/ArrayRef<int64_t>{wSizeStep},
+          /*strides=*/ArrayRef<int64_t>{1}));
+    }
+  } else {
+    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
+    // convolution.
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, res,
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
+          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+    }
+  }
+  return result;
+}
+
+/// Helper function to insert the computed result slices.
+static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
+                                    Value res, int64_t wSize, int64_t wSizeStep,
+                                    SmallVector<Value> &resVals,
+                                    bool isSingleChanneled) {
+
+  if (isSingleChanneled) {
+    // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
+    // This does not depend on kw.
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], res,
+          /*offsets=*/ArrayRef<int64_t>{w},
+          /*strides=*/ArrayRef<int64_t>{1});
+    }
+  } else {
+    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
+    // convolution. This does not depend on kw.
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], res,
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
+    }
+  }
+  return res;
+}
+
 /// Contains the vectorization state and related methods used across the
 /// vectorization process of a given operation.
 struct VectorizationState {
@@ -334,6 +443,7 @@
 /// Helper enum to represent conv1d input traversal order.
 enum class Conv1DOpOrder {
+  W,   // Corresponds to non-channeled 1D convolution operation.
   Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
   Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
 };
 
@@ -417,8 +527,7 @@
       vector::BroadcastableToResult::Success)
     return value;
   Location loc = b.getInsertionPoint()->getLoc();
-  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType,
-                                             value);
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
 }
 
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
@@ -1798,14 +1907,14 @@
 static void bindShapeDims(ShapedType shapedType) {}
 
 template <int N, typename IntTy, typename... IntTy2>
-static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
+static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &... vals) {
   val = shapedType.getShape()[N];
   bindShapeDims<N + 1>(shapedType, vals...);
 }
 
 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
 template <typename... IntTy>
-static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
+static void bindShapeDims(ShapedType shapedType, IntTy &... vals) {
   bindShapeDims<0>(shapedType, vals...);
 }
 
@@ -1832,6 +1941,15 @@
 
 /// Generate a vector implementation for either:
 /// ```
+///   Op def: (     w,     kw  )
+///    Iters: ({Par(), Red()})
+///   Layout: {{w + kw}, {kw}, {w}}
+/// ```
+/// kw is unrolled.
+///
+/// or
+///
+/// ```
 ///   Op def: ( n,     w,   c,    kw, f  )
 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
@@ -1872,8 +1990,10 @@
     resShapedType = resShaped.getType().dyn_cast<ShapedType>();
     if (!lhsShapedType || !rhsShapedType || !resShapedType)
      return;
-    // LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC.
-    if (lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3)
+    // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
+    // (non-channeled convolution -> LHS and RHS both have single dimensions).
+    if (!((lhsShapedType.getRank() == 3 && resShapedType.getRank() == 3) ||
+          (lhsShapedType.getRank() == 1 && resShapedType.getRank() == 1)))
       return;
 
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
@@ -1892,7 +2012,7 @@
     auto rhsRank = rhsShapedType.getRank();
     switch (oper) {
    case Conv:
-      if (rhsRank != 2 && rhsRank!= 3)
+      if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
        return;
      break;
    case Pool:
@@ -1906,6 +2026,15 @@
 
  /// Generate a vector implementation for:
  /// ```
+  ///   Op def: (     w,     kw  )
+  ///    Iters: ({Par(), Red()})
+  ///   Layout: {{w + kw}, {kw}, {w}}
+  /// ```
+  /// kw is always unrolled.
+  ///
+  /// or
+  ///
+  /// ```
  ///   Op def: ( n,     w,   c,    kw, f  )
  ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
  ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
@@ -1919,7 +2048,21 @@
 
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t> lhsShape, rhsShape, resShape;
+    bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
     switch (conv1DOpOrder) {
+    case Conv1DOpOrder::W:
+      // Initialize unused dimensions
+      nSize = fSize = cSize = 0;
+      // out{W}
+      bindShapeDims(resShapedType, wSize);
+      // kernel{kw}
+      bindShapeDims(rhsShapedType, kwSize);
+      lhsShape = {// iw = ow + kw - 1
+                  //   (i.e. 16 convolved with 3 -> 14)
+                  (wSize + kwSize - 1)};
+      rhsShape = {kwSize};
+      resShape = {wSize};
+      break;
     case Conv1DOpOrder::Nwc:
       // out{n, w, f}
       bindShapeDims(resShapedType, nSize, wSize, fSize);
@@ -1997,24 +2140,27 @@
     auto lhsType = VectorType::get(lhsShape, lhsEltType);
     auto rhsType = VectorType::get(rhsShape, rhsEltType);
     auto resType = VectorType::get(resShape, resEltType);
-    // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
-    // 0].
-    Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
-    // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
+    // Zero padding with the corresponding dimensions for lhs, rhs and res.
+    SmallVector<Value> lhsPadding(lhsShape.size(), zero);
+    SmallVector<Value> rhsPadding(rhsShape.size(), zero);
+    SmallVector<Value> resPadding(resShape.size(), zero);
+
+    // Read the whole lhs, rhs and res in one shot (with zero padding).
+    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
+                                                        lhsPadding);
     // This is needed only for Conv.
     Value rhs = nullptr;
     if (oper == Conv)
-      rhs = rewriter.create<vector::TransferReadOp>(
-          loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
-    // Read res slice of size {n, w, f} @ [0, 0, 0].
-    Value res = rewriter.create<vector::TransferReadOp>(
-        loc, resType, resShaped, ValueRange{zero, zero, zero});
-
-    // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output:
-    // {n,w,f}. To reuse the base pattern vectorization case, we do pre
-    // transpose on input, weight, and output.
+      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                    rhsPadding);
+    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
+                                                        resPadding);
+
+    // The base vectorization case for channeled convolution is input: {n,w,c},
+    // weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
+    // vectorization case, we do pre transpose on input, weight, and output.
     switch (conv1DOpOrder) {
+    case Conv1DOpOrder::W:
     case Conv1DOpOrder::Nwc:
       // Base case, so no transposes necessary.
       break;
@@ -2041,45 +2187,35 @@
     //===------------------------------------------------------------------===//
     // Unroll along kw and read slices of lhs and rhs.
     SmallVector<Value> lhsVals, rhsVals, resVals;
-    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
-    for (int64_t kw = 0; kw < kwSize; ++kw) {
-      for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, lhs,
-            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
-            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
-      }
-    }
-    // Extract rhs slice of size {c, f} @ [kw].
+    lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
+                                     kwSize, strideW, dilationW, wSizeStep,
+                                     isSingleChanneled);
     // Do not do for pooling.
     if (oper == Conv)
-      for (int64_t kw = 0; kw < kwSize; ++kw) {
-        rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-            loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
-      }
-    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
-    for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
-          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
-    }
+      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
+    resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
+                                      wSizeStep, isSingleChanneled);
 
     auto linearIndex = [&](int64_t kw, int64_t w) {
       return kw * (wSize / wSizeStep) + w;
     };
 
     // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
+    // perform outerproduct for non-channeled convolution or
     // perform simple arith operation for pooling
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         switch (oper) {
         case Conv:
-          resVals[w] = conv1dSliceAsContraction(rewriter, loc,
-                                                lhsVals[linearIndex(kw, w)],
-                                                rhsVals[kw], resVals[w]);
+          if (isSingleChanneled) {
+            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
+                                                   lhsVals[linearIndex(kw, w)],
+                                                   rhsVals[kw], resVals[w]);
+          } else {
+            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
+                                                  lhsVals[linearIndex(kw, w)],
+                                                  rhsVals[kw], resVals[w]);
+          }
           break;
         case Pool:
           resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
@@ -2089,22 +2225,17 @@
       }
     }
 
-    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
-    // This does not depend on kw.
-    for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, resVals[w], res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
-    }
+    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
+                                 isSingleChanneled);
 
     //===------------------------------------------------------------------===//
     // End vector-only rewrite part
     //===------------------------------------------------------------------===//
-    // The base vectorization case is output: {n,w,f}
+    // The base vectorization case for channeled convolution is output: {n,w,f}
     // To reuse the result from base pattern vectorization case, we post
     // transpose the base case result.
     switch (conv1DOpOrder) {
+    case Conv1DOpOrder::W:
     case Conv1DOpOrder::Nwc:
       // Base case, so no transposes necessary.
      break;
@@ -2116,10 +2247,8 @@
    }
    }
 
-    // Write back res slice of size {n, w, f} @ [0, 0, 0].
     return rewriter
-        .create<vector::TransferWriteOp>(loc, res, resShaped,
-                                         ValueRange{zero, zero, zero})
+        .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
         .getOperation();
   }
 
@@ -2136,6 +2265,14 @@
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
   }
 
+  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
+  // convolution.
+  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
+                                  Value lhs, Value rhs, Value res) {
+    return rewriter.create<vector::OuterProductOp>(
+        loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
+  }
+
   // Create a reduction: lhs{n, w, c} -> res{n, w, c}
   Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
                     Value res) {
@@ -2234,15 +2371,17 @@
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = depthwiseConv1dSliceAsMulAcc(
-            rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
+        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
+                                                  lhsVals[linearIndex(kw, w)],
+                                                  rhsVals[kw], resVals[w]);
      }
    }
 
     // Its possible we failed to create the Fma.
     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
       // Manually revert (in reverse order) to avoid leaving a bad IR state.
-      for (auto &collection : {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
+      for (auto &collection :
+           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
         for (Value v : collection)
           rewriter.eraseOp(v.getDefiningOp());
       return rewriter.notifyMatchFailure(op, "failed to create FMA");
@@ -2308,6 +2447,24 @@
     return rewriter.create<arith::AddIOp>(loc, mul, res);
   }
 
+  /// Entry point for non-channeled convolution:
+  ///   {{w + kw}, {kw}, {w}}
+  FailureOr<Operation *> generateConv() {
+    AffineExpr w, kw;
+    bindDims(ctx, w, kw);
+    if (!iters({Par(), Red()}))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to match conv::W 1-par 1-red");
+
+    // No transposition needed.
+    if (layout({/*lhsIndex*/ {w + kw},
+                /*rhsIndex*/ {kw},
+                /*resIndex*/ {w}}))
+      return conv(Conv1DOpOrder::W);
+
+    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
+  }
+
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
   FailureOr<Operation *> generateNwcConv() {
@@ -2468,7 +2625,10 @@
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
   Conv1DGenerator e(rewriter, op, stride, dilation);
-  auto res = e.generateNwcConv();
+  auto res = e.generateConv();
+  if (succeeded(res))
+    return res;
+  res = e.generateNwcConv();
   if (succeeded(res))
     return res;
   res = e.generateNcwConv();
Index: mlir/test/Dialect/Linalg/vectorize-convolution.mlir
===================================================================
--- mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -460,6 +460,58 @@
 
 // Write the result back in one shot.
 // CHECK: vector.transfer_write %[[RES]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+// -----
+
+func.func @conv1d_8_tensor(%input: tensor<11xf32>, %filter: tensor<4xf32>, %output: tensor<8xf32>) -> tensor<8xf32> {
+  %0 = linalg.conv_1d ins(%input, %filter : tensor<11xf32>, tensor<4xf32>)
+    outs(%output : tensor<8xf32>) -> tensor<8xf32>
+  return %0 : tensor<8xf32>
+}
+
+// CHECK: func @conv1d_8_tensor
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<11xf32>, %[[FILTER:.+]]: tensor<4xf32>, %[[OUTPUT:.+]]: tensor<8xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]]], %[[F0]]
+
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [1], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
+// CHECK: %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [2], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
+// CHECK: %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [3], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
+
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xf32>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xf32>
+// CHECK: %[[V_FILTER_2:.+]] = vector.extract %[[V_FILTER_R]][2] : vector<4xf32>
+// CHECK: %[[V_FILTER_3:.+]] = vector.extract %[[V_FILTER_R]][3] : vector<4xf32>
+
+/// w == 0, kw == 0
+// CHECK: %[[RES_0:.+]] = vector.outerproduct
+// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]] {kind = #vector.kind<add>}
+// CHECK-SAME: : vector<8xf32>, f32
+/// w == 0, kw == 1
+// CHECK: %[[RES_1:.+]] = vector.outerproduct
+// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_1]], %[[RES_0]] {kind = #vector.kind<add>}
+// CHECK-SAME: : vector<8xf32>, f32
+/// w == 0, kw == 2
+// CHECK: %[[RES_2:.+]] = vector.outerproduct
+// CHECK-SAME: %[[V_INPUT_2]], %[[V_FILTER_2]], %[[RES_1]] {kind = #vector.kind<add>}
+// CHECK-SAME: : vector<8xf32>, f32
+/// w == 0, kw == 3
+// CHECK: %[[RES_3:.+]] = vector.outerproduct
+// CHECK-SAME: %[[V_INPUT_3]], %[[V_FILTER_3]], %[[RES_2]] {kind = #vector.kind<add>}
+// CHECK-SAME: : vector<8xf32>, f32
+
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[RES_3]], %[[OUTPUT]][%[[C0]]]
+
 // -----
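For reference, below is a hand-condensed sketch of the IR the new Conv1DOpOrder::W path is expected to produce for the conv1d_8_tensor case above. It is assembled from the CHECK lines of the test, not from verified mlir-opt output; the function name is made up, and only the kw == 0 step is spelled out (kw == 1..3 repeat the same extract/outerproduct pattern with the input offset advanced by kw, chaining the accumulators).

func.func @conv1d_8_vectorized_sketch(%input: tensor<11xf32>, %filter: tensor<4xf32>,
                                      %output: tensor<8xf32>) -> tensor<8xf32> {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.000000e+00 : f32
  // Whole-tensor reads (zero index vectors, padding value %f0).
  %in  = vector.transfer_read %input[%c0], %f0 : tensor<11xf32>, vector<11xf32>
  %flt = vector.transfer_read %filter[%c0], %f0 : tensor<4xf32>, vector<4xf32>
  %out = vector.transfer_read %output[%c0], %f0 : tensor<8xf32>, vector<8xf32>
  // kw == 0: slice the input at offset w + kw = 0 and accumulate one
  // vector-times-scalar outer product into the running result.
  %in0  = vector.extract_strided_slice %in {offsets = [0], sizes = [8], strides = [1]}
            : vector<11xf32> to vector<8xf32>
  %flt0 = vector.extract %flt[0] : vector<4xf32>
  %acc0 = vector.outerproduct %in0, %flt0, %out {kind = #vector.kind<add>}
            : vector<8xf32>, f32
  // ...kw == 1, 2, 3 feed %acc0 -> %acc1 -> %acc2 -> %acc3 in the same way...
  %res = vector.transfer_write %acc0, %output[%c0] : vector<8xf32>, tensor<8xf32>
  return %res : tensor<8xf32>
}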