diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -91,6 +91,12 @@ return res; } +/// Helper enum to represent conv1d input traversal order. +enum class Conv1DOpOrder { + Ncw, // Corresponds to operation that traverses the input in (n, c, w) order. + Nwc // Corresponds to operation that traverses the input in (n, w, c) order. +}; + /// Helper data structure to represent the result of vectorization. /// In certain specific cases, like terminators, we do not want to propagate/ enum VectorizationStatus { @@ -1312,14 +1318,23 @@ /// or /// /// ``` +/// Op def: ( n, c, w, f, kw ) +/// Iters: ({Par(), Par(), Par(), Red(), Red()}) +/// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}} +/// ``` +/// kw is unrolled, w is unrolled iff dilationW > 1. +/// +/// or +/// +/// ``` /// Op def: ( n, w, c, kw ) /// Iters: ({Par(), Par(), Par(), Red()}) /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} /// ``` /// kw is unrolled, w is unrolled iff dilationW > 1. -struct Conv1DNwcGenerator : public StructuredGenerator { - Conv1DNwcGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW, - int dilationW) +struct Conv1DGenerator : public StructuredGenerator { + Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW, + int dilationW) : StructuredGenerator(builder, linalgOp), strideW(strideW), dilationW(dilationW) { // Determine whether `linalgOp` can be generated with this generator @@ -1382,15 +1397,45 @@ /// kw is always unrolled. /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is /// > 1.
- FailureOr conv() { + FailureOr conv(Conv1DOpOrder conv1DOpOrder) { if (!valid) return failure(); int64_t nSize, wSize, cSize, kwSize, fSize; - // kernel{kw, c, f} - bindShapeDims(rhsShapedType, kwSize, cSize, fSize); - // out{n, w, f} - bindShapeDims(resShapedType, nSize, wSize); + SmallVector lhsShape, rhsShape, resShape; + switch (conv1DOpOrder) { + case Conv1DOpOrder::Nwc: + // kernel{kw, c, f} + bindShapeDims(rhsShapedType, kwSize, cSize, fSize); + // out{n, w, f} + bindShapeDims(resShapedType, nSize, wSize); + lhsShape = {nSize, + // iw = (ow - 1) * sw + (kw - 1) * dw + 1 + // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) + // Perform the proper inclusive -> exclusive -> inclusive. + ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - + 1, + cSize}; + rhsShape = {kwSize, cSize, fSize}; + resShape = {nSize, wSize, fSize}; + break; + case Conv1DOpOrder::Ncw: + // kernel{f, c, kw} + bindShapeDims(rhsShapedType, fSize, cSize, kwSize); + // out{n, f, w} + bindShapeDims(resShapedType, nSize, fSize, wSize); + lhsShape = {nSize, cSize, + // iw = (ow - 1) * sw + (kw - 1) * dw + 1 + // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) + // Perform the proper inclusive -> exclusive -> inclusive. + ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - + 1}; + rhsShape = {fSize, cSize, kwSize}; + resShape = {nSize, fSize, wSize}; + break; + default: + return failure(); + } vector::TransferWriteOp write; Value zero = builder.create(loc, 0); @@ -1403,17 +1448,9 @@ Type lhsEltType = lhsShapedType.getElementType(); Type rhsEltType = rhsShapedType.getElementType(); Type resEltType = resShapedType.getElementType(); - VectorType lhsType = VectorType::get( - {nSize, - // iw = ow * sw + kw * dw - 1 - // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) - // Perform the proper inclusive -> exclusive -> inclusive.
- ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1, - cSize}, - lhsEltType); - VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType); - VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType); - + auto lhsType = VectorType::get(lhsShape, lhsEltType); + auto rhsType = VectorType::get(rhsShape, rhsEltType); + auto resType = VectorType::get(resShape, resEltType); // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0, // 0]. Value lhs = builder.create( @@ -1425,6 +1462,29 @@ Value res = builder.create( loc, resType, resShaped, ValueRange{zero, zero, zero}); + // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output: + // {n,w,f}. To reuse the base pattern vectorization case, we + // pre-transpose the input, weight, and output. + switch (conv1DOpOrder) { + case Conv1DOpOrder::Nwc: + // Base case, so no transposes necessary. + break; + case Conv1DOpOrder::Ncw: + // To match base vectorization case, we pre-transpose current case. + // ncw -> nwc + static constexpr std::array permLhs = {0, 2, 1}; + lhs = builder.create(loc, lhs, permLhs); + // fcw -> wcf + static constexpr std::array permRhs = {2, 1, 0}; + rhs = builder.create(loc, rhs, permRhs); + // nfw -> nwf + static constexpr std::array permRes = {0, 2, 1}; + res = builder.create(loc, res, permRes); + break; + default: + return failure(); + } + //===------------------------------------------------------------------===// // Begin vector-only rewrite part //===------------------------------------------------------------------===// @@ -1478,6 +1538,22 @@ // End vector-only rewrite part //===------------------------------------------------------------------===// + // The base vectorization case is output: {n,w,f} + // To reuse the result from the base pattern vectorization case, we + // post-transpose the base case result. + switch (conv1DOpOrder) { + case Conv1DOpOrder::Nwc: + // Base case, so no transposes necessary.
+ break; + case Conv1DOpOrder::Ncw: + // nwf -> nfw + static constexpr std::array perm = {0, 2, 1}; + res = builder.create(loc, res, perm); + break; + default: + return failure(); + } + // Write back res slice of size {n, w, f} @ [0, 0, 0]. return builder .create(loc, res, resShaped, @@ -1619,7 +1695,7 @@ /// Entry point that transposes into the common form: /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} - FailureOr generateConv() { + FailureOr generateNwcConv() { AffineExpr n, w, f, kw, c; bindDims(ctx, n, w, f, kw, c); if (!iters({Par(), Par(), Par(), Red(), Red()})) @@ -1629,7 +1705,23 @@ if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, /*rhsIndex*/ {kw, c, f}, /*resIndex*/ {n, w, f}})) - return conv(); + return conv(Conv1DOpOrder::Nwc); + return failure(); + } + + /// Entry point for the ncw layout {{n, c, strideW * w + dilationW * kw}, + /// {f, c, kw}, {n, f, w}}; conv() transposes it into the common nwc form. + FailureOr generateNcwConv() { + AffineExpr n, w, f, kw, c; + bindDims(ctx, n, f, w, c, kw); + if (!iters({Par(), Par(), Par(), Red(), Red()})) + return failure(); + + if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw}, + /*rhsIndex*/ {f, c, kw}, + /*resIndex*/ {n, f, w}})) + return conv(Conv1DOpOrder::Ncw); + return failure(); } @@ -1668,8 +1760,11 @@ auto dilations = op->getAttrOfType("dilations"); auto stride = strides ? *strides.getValues().begin() : 1; auto dilation = dilations ?
*dilations.getValues().begin() : 1; - Conv1DNwcGenerator e(b, op, stride, dilation); - auto res = e.generateConv(); + Conv1DGenerator e(b, op, stride, dilation); + auto res = e.generateNwcConv(); + if (succeeded(res)) + return res; + res = e.generateNcwConv(); if (succeeded(res)) return res; return e.generateDilatedConv(); diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir --- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir @@ -185,6 +185,218 @@ // Write the result back in one shot. // CHECK: vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] +// ----- + +func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x1xf32>, %output: memref<4x8x2xf32>) { + linalg.conv_1d_ncw_fcw + {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x3x6xf32>, memref<8x3x1xf32>) + outs(%output : memref<4x8x2xf32>) + return +} + +// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK: func @conv1d_ncw_4x8x2_memref +// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x6xf32>, %[[FILTER:.+]]: memref<8x3x1xf32>, %[[OUTPUT:.+]]: memref<4x8x2xf32>) + +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 + +/// Read the whole data in one shot. +// CHECK-DAG: %[[V_NWC_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK-DAG: %[[V_NWC_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK-DAG: %[[V_NWC_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] + +/// Transpose input, filter and output to nwc, wcf and nwf format.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transpose %[[V_NWC_INPUT_R]], [0, 2, 1] +// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transpose %[[V_NWC_FILTER_R]], [2, 1, 0] +// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transpose %[[V_NWC_OUTPUT_R]], [0, 2, 1] + +// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32> +// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32> + +// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32> + +// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32> +// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]] +// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32> + +/// w == 0, kw == 0 +// CHECK: %[[CONTRACT_0:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]] +// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> + +/// w == 1, kw == 0 +// CHECK: %[[CONTRACT_1:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]] +// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> + +/// w == 0, kw == 0 +// CHECK: %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]] +//
CHECK-SAME: {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32> +/// w == 1, kw == 0 +// CHECK: %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_1]], %[[RES_0]] +// CHECK-SAME: {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32> + +/// Transpose result to ncw format. +// CHECK: %[[RES_2:.+]] = vector.transpose %[[RES_1]], [0, 2, 1] + +// Write the result back in one shot. +// CHECK: vector.transfer_write %[[RES_2]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + +// ----- + +func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x2xf32>, %output: memref<4x8x2xf32>) { + linalg.conv_1d_ncw_fcw + {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x3x6xf32>, memref<8x3x2xf32>) + outs(%output : memref<4x8x2xf32>) + return +} + +// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK: func @conv1d_ncw_4x8x2_memref +// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x6xf32>, %[[FILTER:.+]]: memref<8x3x2xf32>, %[[OUTPUT:.+]]: memref<4x8x2xf32>) + +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 + +/// Read the whole data in one shot. +// CHECK-DAG: %[[V_NWC_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK-DAG: %[[V_NWC_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK-DAG: %[[V_NWC_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] + +/// Transpose input, filter and output to nwc, wcf and nwf format.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transpose %[[V_NWC_INPUT_R]], [0, 2, 1] +// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transpose %[[V_NWC_FILTER_R]], [2, 1, 0] +// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transpose %[[V_NWC_OUTPUT_R]], [0, 2, 1] + +// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> +// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> +// CHECK: %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> +// CHECK: %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> + +// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32> +// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32> + +// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32> +// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]] +// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32> + +/// w == 0, kw == 0 +// CHECK: %[[CONTRACT_0:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_0]] +// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> +/// w == 1, kw == 0
+// CHECK: %[[CONTRACT_1:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_0]], %[[V_OUTPUT_1]] +// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> +/// w == 0, kw == 1 +// CHECK: %[[CONTRACT_2:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: %[[V_INPUT_2]], %[[V_FILTER_1]], %[[CONTRACT_0]] +// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> +/// w == 1, kw == 1 +// CHECK: %[[CONTRACT_3:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: %[[V_INPUT_3]], %[[V_FILTER_1]], %[[CONTRACT_1]] +// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> + +/// w == 0, kw == 0 +// CHECK: %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_2]], %[[V_OUTPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32> +/// w == 1, kw == 0 +// CHECK: %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_3]], %[[RES_0]] +// CHECK-SAME: {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32> + +/// Transpose result to ncw format. +// CHECK: %[[RES_2:.+]] = vector.transpose %[[RES_1]], [0, 2, 1] + +// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[RES_2]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + +// ----- + +func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x2xf32>, %output: memref<4x8x2xf32>) { + linalg.conv_1d_ncw_fcw + {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %filter : memref<4x3x6xf32>, memref<8x3x2xf32>) + outs(%output : memref<4x8x2xf32>) + return +} + +// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK: func @conv1d_ncw_4x8x2_memref +// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x6xf32>, %[[FILTER:.+]]: memref<8x3x2xf32>, %[[OUTPUT:.+]]: memref<4x8x2xf32>) + +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 + +/// Read the whole data in one shot. +// CHECK-DAG: %[[V_NWC_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK-DAG: %[[V_NWC_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK-DAG: %[[V_NWC_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] + +/// Transpose input, filter and output to nwc, wcf and nwf format.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transpose %[[V_NWC_INPUT_R]], [0, 2, 1] +// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transpose %[[V_NWC_FILTER_R]], [2, 1, 0] +// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transpose %[[V_NWC_OUTPUT_R]], [0, 2, 1] + +// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32> +// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32> + +// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32> +// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32> + +/// w == 0, kw == 0 +// CHECK: %[[CONTRACT_0:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]] +// CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32> +/// w == 0, kw == 1 +// CHECK: %[[CONTRACT_1:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]] +// CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32> + +/// Transpose result to ncw format. +// CHECK: %[[RES:.+]] = vector.transpose %[[CONTRACT_1]], [0, 2, 1] + +// Write the result back in one shot. +// CHECK: vector.transfer_write %[[RES]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + + // ----- func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {