Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -334,6 +334,7 @@
 
 /// Helper enum to represent conv1d input traversal order.
 enum class Conv1DOpOrder {
+  W,   // Corresponds to non-channeled 1D convolution operation.
   Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
   Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
 };
@@ -1823,6 +1824,15 @@
 
 /// Generate a vector implementation for either:
 /// ```
+///   Op def: ( w, kw )
+///    Iters: ({Par(), Red()})
+///   Layout: {{w + kw}, {kw}, {w}}
+/// ```
+/// kw is unrolled.
+///
+/// or
+///
+/// ```
 ///   Op def: ( n, w, c, kw, f )
 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
@@ -1863,8 +1873,10 @@
     resShapedType = resShaped.getType().dyn_cast<ShapedType>();
     if (!lhsShapedType || !rhsShapedType || !resShapedType)
       return;
-    // LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC.
-    if (lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3)
+    // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
+    // (non-channeled convolution -> LHS and RHS both have single dimensions).
+    if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
+        (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
       return;
 
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
@@ -1883,7 +1895,7 @@
     auto rhsRank = rhsShapedType.getRank();
     switch (oper) {
     case Conv:
-      if (rhsRank != 2 && rhsRank!= 3)
+      if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
         return;
       break;
     case Pool:
@@ -1897,6 +1909,15 @@
 
   /// Generate a vector implementation for:
   /// ```
+  ///   Op def: ( w, kw )
+  ///    Iters: ({Par(), Red()})
+  ///   Layout: {{w + kw}, {kw}, {w}}
+  /// ```
+  /// kw is always unrolled.
+  ///
+  /// or
+  ///
+  /// ```
   ///   Op def: ( n, w, c, kw, f )
   ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
@@ -1911,6 +1932,17 @@
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
     switch (conv1DOpOrder) {
+    case Conv1DOpOrder::W:
+      // out{W}
+      bindShapeDims(resShapedType, wSize);
+      // kernel{kw}
+      bindShapeDims(rhsShapedType, kwSize);
+      lhsShape = {// iw = ow + kw - 1
+                  //   (i.e. 16 convolved with 3 -> 14)
+                  (wSize + kwSize - 1)};
+      rhsShape = {kwSize};
+      resShape = {wSize};
+      break;
     case Conv1DOpOrder::Nwc:
       // out{n, w, f}
       bindShapeDims(resShapedType, nSize, wSize, fSize);
@@ -1988,24 +2020,38 @@
     auto lhsType = VectorType::get(lhsShape, lhsEltType);
     auto rhsType = VectorType::get(rhsShape, rhsEltType);
     auto resType = VectorType::get(resShape, resEltType);
-    // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
-    // 0].
-    Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
-    // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
-    // This is needed only for Conv.
-    Value rhs = nullptr;
-    if (oper == Conv)
-      rhs = rewriter.create<vector::TransferReadOp>(
-          loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
-    // Read res slice of size {n, w, f} @ [0, 0, 0].
-    Value res = rewriter.create<vector::TransferReadOp>(
-        loc, resType, resShaped, ValueRange{zero, zero, zero});
+    Value lhs, rhs, res;
+    if (conv1DOpOrder == Conv1DOpOrder::W) {
+      // Read lhs slice of size {w + kw} @ [0].
+      lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
+                                                    ValueRange{zero});
+      // Read rhs slice of size {kw} @ [0].
+      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                    ValueRange{zero});
+      // Read res slice of size {w} @ [0].
+      res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
+                                                    ValueRange{zero});
+    } else {
+      // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
+      // 0].
+      lhs = rewriter.create<vector::TransferReadOp>(
+          loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+      // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
+      // This is needed only for Conv.
+      rhs = nullptr;
+      if (oper == Conv)
+        rhs = rewriter.create<vector::TransferReadOp>(
+            loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
+      // Read res slice of size {n, w, f} @ [0, 0, 0].
+      res = rewriter.create<vector::TransferReadOp>(
+          loc, resType, resShaped, ValueRange{zero, zero, zero});
+    }
 
-    // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output:
-    // {n,w,f}. To reuse the base pattern vectorization case, we do pre
-    // transpose on input, weight, and output.
+    // The base vectorization case for channeled convolution is input: {n,w,c},
+    // weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
+    // vectorization case, we do pre transpose on input, weight, and output.
     switch (conv1DOpOrder) {
+    case Conv1DOpOrder::W:
     case Conv1DOpOrder::Nwc:
       // Base case, so no transposes necessary.
       break;
@@ -2032,30 +2078,55 @@
     //===------------------------------------------------------------------===//
     // Unroll along kw and read slices of lhs and rhs.
     SmallVector<Value> lhsVals, rhsVals, resVals;
-    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
-    for (int64_t kw = 0; kw < kwSize; ++kw) {
-      for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, lhs,
-            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
-            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+    if (conv1DOpOrder == Conv1DOpOrder::W) {
+      // Extract lhs slice of size {wSizeStep} @ [w + kw].
+      for (int64_t kw = 0; kw < kwSize; ++kw) {
+        for (int64_t w = 0; w < wSize; w += wSizeStep) {
+          lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+              loc, lhs,
+              /*offsets=*/ArrayRef<int64_t>{w + kw},
+              /*sizes=*/ArrayRef<int64_t>{wSizeStep},
+              /*strides=*/ArrayRef<int64_t>{1}));
+        }
+      }
+    } else {
+      // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
+      for (int64_t kw = 0; kw < kwSize; ++kw) {
+        for (int64_t w = 0; w < wSize; w += wSizeStep) {
+          lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+              loc, lhs,
+              /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
+              /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+              /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+        }
       }
     }
-    // Extract rhs slice of size {c, f} @ [kw].
+    // Extract rhs slice of size {c, f}/{1} @ [kw].
     // Do not do for pooling.
     if (oper == Conv)
       for (int64_t kw = 0; kw < kwSize; ++kw) {
         rhsVals.push_back(rewriter.create<vector::ExtractOp>(
            loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
       }
-    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
-    for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
-          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+
+    if (conv1DOpOrder == Conv1DOpOrder::W) {
+      // Extract res slice: {wSizeStep} @ [w].
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, res,
+            /*offsets=*/ArrayRef<int64_t>{w},
+            /*sizes=*/ArrayRef<int64_t>{wSizeStep},
+            /*strides=*/ArrayRef<int64_t>{1}));
+      }
+    } else {
+      // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, res,
+            /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+      }
     }
 
     auto linearIndex = [&](int64_t kw, int64_t w) {
@@ -2063,14 +2134,21 @@
     };
 
     // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
+    // perform outerproduct for non-channeled convolution or
    // perform simple arith operation for pooling
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         switch (oper) {
         case Conv:
-          resVals[w] = conv1dSliceAsContraction(rewriter, loc,
-                                                lhsVals[linearIndex(kw, w)],
-                                                rhsVals[kw], resVals[w]);
+          if (conv1DOpOrder == Conv1DOpOrder::W) {
+            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
+                                                   lhsVals[linearIndex(kw, w)],
+                                                   rhsVals[kw], resVals[w]);
+          } else {
+            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
+                                                  lhsVals[linearIndex(kw, w)],
+                                                  rhsVals[kw], resVals[w]);
+          }
           break;
         case Pool:
           resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
@@ -2080,22 +2158,34 @@
       }
     }
-    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
-    // This does not depend on kw.
-    for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, resVals[w], res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
+    if (conv1DOpOrder == Conv1DOpOrder::W) {
+      // Write back res slice: {wSizeStep} @ [w].
+      // This does not depend on kw.
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        res = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, resVals[w], res,
+            /*offsets=*/ArrayRef<int64_t>{w},
+            /*strides=*/ArrayRef<int64_t>{1});
+      }
+    } else {
+      // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
+      // This does not depend on kw.
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        res = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, resVals[w], res,
+            /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+            /*strides=*/ArrayRef<int64_t>{1, 1, 1});
+      }
     }
 
     //===------------------------------------------------------------------===//
     // End vector-only rewrite part
     //===------------------------------------------------------------------===//
 
-    // The base vectorization case is output: {n,w,f}
+    // The base vectorization case for channeled convolution is output: {n,w,f}
     // To reuse the result from base pattern vectorization case, we post
     // transpose the base case result.
     switch (conv1DOpOrder) {
+    case Conv1DOpOrder::W:
     case Conv1DOpOrder::Nwc:
       // Base case, so no transposes necessary.
       break;
@@ -2107,11 +2197,19 @@
     }
     }
 
-    // Write back res slice of size {n, w, f} @ [0, 0, 0].
-    return rewriter
-        .create<vector::TransferWriteOp>(loc, res, resShaped,
-                                         ValueRange{zero, zero, zero})
-        .getOperation();
+    if (conv1DOpOrder == Conv1DOpOrder::W) {
+      // Write back res slice of size {w} @ [0].
+      return rewriter
+          .create<vector::TransferWriteOp>(loc, res, resShaped,
+                                           ValueRange{zero})
+          .getOperation();
+    } else {
+      // Write back res slice of size {n, w, f} @ [0, 0, 0].
+      return rewriter
+          .create<vector::TransferWriteOp>(loc, res, resShaped,
+                                           ValueRange{zero, zero, zero})
+          .getOperation();
+    }
   }
 
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
@@ -2127,6 +2225,14 @@
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
   }
 
+  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
+  // convolution.
+  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
+                                  Value lhs, Value rhs, Value res) {
+    return rewriter.create<vector::OuterProductOp>(loc, res.getType(), lhs, rhs,
+                                                   res, vector::CombiningKind::ADD);
+  }
+
   // Create a reduction: lhs{n, w, c} -> res{n, w, c}
   Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
                     Value res) {
@@ -2299,6 +2405,24 @@
     return rewriter.create<arith::AddIOp>(loc, mul, res);
   }
 
+  /// Entry point for non-channeled convolution:
+  ///   {{w + kw}, {kw}, {w}}
+  FailureOr<Operation *> generateConv() {
+    AffineExpr w, kw;
+    bindDims(ctx, w, kw);
+    if (!iters({Par(), Red()}))
+      return rewriter.notifyMatchFailure(
+          op, "failed to match conv::W 1-par 1-red");
+
+    // No transposition needed.
+    if (layout({/*lhsIndex*/ {w + kw},
+                /*rhsIndex*/ {kw},
+                /*resIndex*/ {w}}))
+      return conv(Conv1DOpOrder::W);
+
+    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
+  }
+
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
   FailureOr<Operation *> generateNwcConv() {
@@ -2459,7 +2583,10 @@
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
   Conv1DGenerator e(rewriter, op, stride, dilation);
-  auto res = e.generateNwcConv();
+  auto res = e.generateConv();
+  if (succeeded(res))
+    return res;
+  res = e.generateNwcConv();
   if (succeeded(res))
     return res;
   res = e.generateNcwConv();
Index: mlir/test/Dialect/Linalg/vectorize-convolution.mlir
===================================================================
--- mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -460,6 +460,58 @@
 // Write the result back in one shot.
 // CHECK: vector.transfer_write %[[RES]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
+// -----
+
+func.func @conv1d_8_tensor(%input: tensor<11xf32>, %filter: tensor<4xf32>, %output: tensor<8xf32>) -> tensor<8xf32> {
+  %0 = linalg.conv_1d ins(%input, %filter : tensor<11xf32>, tensor<4xf32>)
+                     outs(%output : tensor<8xf32>) -> tensor<8xf32>
+  return %0 : tensor<8xf32>
+}
+
+// CHECK: func @conv1d_8_tensor
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<11xf32>, %[[FILTER:.+]]: tensor<4xf32>, %[[OUTPUT:.+]]: tensor<8xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]]], %[[F0]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]]], %[[F0]]
+
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [1], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
+// CHECK: %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [2], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
+// CHECK: %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [3], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
+
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xf32>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xf32>
+// CHECK: %[[V_FILTER_2:.+]] = vector.extract %[[V_FILTER_R]][2] : vector<4xf32>
+// CHECK: %[[V_FILTER_3:.+]] = vector.extract %[[V_FILTER_R]][3] : vector<4xf32>
+
+/// w == 0, kw == 0
+// CHECK: %[[RES_0:.+]] = vector.outerproduct
+// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]] {kind = #vector.kind<add>}
+// CHECK-SAME: : vector<8xf32>, f32
+/// w == 0, kw == 1
+// CHECK: %[[RES_1:.+]] = vector.outerproduct
+// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_1]], %[[RES_0]] {kind = #vector.kind<add>}
+// CHECK-SAME: : vector<8xf32>, f32
+/// w == 0, kw == 2
+// CHECK: %[[RES_2:.+]] = vector.outerproduct
+// CHECK-SAME: %[[V_INPUT_2]], %[[V_FILTER_2]], %[[RES_1]] {kind = #vector.kind<add>}
+// CHECK-SAME: : vector<8xf32>, f32
+/// w == 0, kw == 3
+// CHECK: %[[RES_3:.+]] = vector.outerproduct
+// CHECK-SAME: %[[V_INPUT_3]], %[[V_FILTER_3]], %[[RES_2]] {kind = #vector.kind<add>}
+// CHECK-SAME: : vector<8xf32>, f32
+
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[RES_3]], %[[OUTPUT]][%[[C0]]]
+
 // -----
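For reference only (not part of the patch): the layout {{w + kw}, {kw}, {w}} with iterators ({Par(), Red()}) that the new generateConv() entry point matches is the generic form of linalg.conv_1d. A minimal sketch of that form is below; the function and map names are illustrative, and iterator_types are spelled as strings as in MLIR of this vintage (newer MLIR spells them as #linalg.iterator_type<...>).

#conv1d_input_map  = affine_map<(w, kw) -> (w + kw)>
#conv1d_filter_map = affine_map<(w, kw) -> (kw)>
#conv1d_output_map = affine_map<(w, kw) -> (w)>

func.func @conv1d_as_generic(%input: tensor<11xf32>, %filter: tensor<4xf32>,
                             %output: tensor<8xf32>) -> tensor<8xf32> {
  // One parallel dimension (w) and one reduction dimension (kw), matching
  // Iters: ({Par(), Red()}) and Layout: {{w + kw}, {kw}, {w}}.
  %0 = linalg.generic
      {indexing_maps = [#conv1d_input_map, #conv1d_filter_map, #conv1d_output_map],
       iterator_types = ["parallel", "reduction"]}
      ins(%input, %filter : tensor<11xf32>, tensor<4xf32>)
      outs(%output : tensor<8xf32>) {
  ^bb0(%in: f32, %flt: f32, %acc: f32):
    %mul = arith.mulf %in, %flt : f32
    %add = arith.addf %mul, %acc : f32
    linalg.yield %add : f32
  } -> tensor<8xf32>
  return %0 : tensor<8xf32>
}

Given this structure, the vectorizer reads the three operands whole, unrolls kw, and accumulates each shifted input slice into the result with vector.outerproduct against a scalar filter element, which is exactly what the CHECK lines of @conv1d_8_tensor above verify.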