diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1454,18 +1454,28 @@
 //===----------------------------------------------------------------------===//
 // Convolution vectorization patterns
 //===----------------------------------------------------------------------===//
+
+/// Return true if the given `linalgOp`'s region is mul + add for convolution.
+static bool hasMulAddRegionForConv(LinalgOp linalgOp) {
+  // Check for reduction `add` preceded by `mul`.
+  Operation *reduceOp = matchLinalgReduction(linalgOp.getOutputOperand(0));
+  if (!reduceOp)
+    return false;
+  llvm::Optional<vector::CombiningKind> maybeKind;
+  maybeKind = getKindForOp(reduceOp);
+  if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
+    return false;
+  maybeKind = getKindForOp(&(linalgOp->getRegion(0).front().front()));
+  if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
+    return false;
+  return true;
+}
+
 namespace {
-/// Generate a vector implementation for:
-/// ```
-///   Op def: (     n,     w,     c,    kw,    f   )
-///    Iters: ({Par(), Par(), Par(), Red(), Red()})
-///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
-/// ```
-/// w and kw are unrolled.
-/// TODO: do not unroll w (resp. kw) when the strideW ( resp. dilationW) is > 1.
-struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
-  Conv1D_NWC_WCF_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
-                           int dilationW)
+/// Generate vector implementations for linalg ops that are 1-D convolution ops.
+struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
+  Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
+                  int dilationW)
       : StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false),
         strideW(strideW), dilationW(dilationW) {
     // Determine whether `linalgOp` can be generated with this generator
@@ -1483,16 +1493,7 @@
         resShapedType.getRank() != 3)
       return;
 
-    // Check for reduction `add` preceded by `mul`.
-    Operation *reduceOp = matchLinalgReduction(linalgOp.getOutputOperand(0));
-    if (!reduceOp)
-      return;
-    llvm::Optional<vector::CombiningKind> maybeKind;
-    maybeKind = getKindForOp(reduceOp);
-    if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
-      return;
-    maybeKind = getKindForOp(&(linalgOp->getRegion(0).front().front()));
-    if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
+    if (!hasMulAddRegionForConv(linalgOp))
       return;
 
     // The op is now known to be valid.
@@ -1577,6 +1578,9 @@
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
   FailureOr<Operation *> generateConv() {
+    if (!valid)
+      return failure();
+
     AffineExpr n, w, f, kw, c;
     bindDims(ctx, n, w, f, kw, c);
 
@@ -1584,9 +1588,9 @@
       return failure();
 
     // No transposition needed.
-    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
-                /*rhsIndex*/ {kw, c, f},
-                /*resIndex*/ {n, w, f}}))
+    if (layout({/*lhsIndex=*/{n, strideW * w + dilationW * kw, c},
+                /*rhsIndex=*/{kw, c, f},
+                /*resIndex=*/{n, w, f}}))
      return conv();
     return failure();
   }
@@ -1597,6 +1601,184 @@
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
 };
+
+/// Generate vector implementations for linalg ops that are 2-D convolution ops.
+struct Conv2DGenerator : public StructuredGenerator<LinalgOp> {
+  Conv2DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideH,
+                  int strideW, int dilationH, int dilationW)
+      : StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false),
+        strideH(strideH), strideW(strideW), dilationH(dilationH),
+        dilationW(dilationW) {
+    // Determine whether `linalgOp` can be generated with this generator.
+    if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
+      return;
+
+    lhsShaped = linalgOp.inputs()[0];
+    rhsShaped = linalgOp.inputs()[1];
+    resShaped = linalgOp.outputs()[0];
+    lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
+    rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
+    resShapedType = resShaped.getType().dyn_cast<ShapedType>();
+    if (!lhsShapedType || !rhsShapedType || !resShapedType)
+      return;
+    if (lhsShapedType.getRank() != 4 || rhsShapedType.getRank() != 4 ||
+        resShapedType.getRank() != 4)
+      return;
+
+    if (!hasMulAddRegionForConv(linalgOp))
+      return;
+
+    // The op is now known to be valid.
+    valid = true;
+  }
+
+  /// Generate a vector implementation for:
+  /// ```
+  ///   Op def: (     n,    oh,    ow,    oc,    fh,    fw,    ic)
+  ///    Iters: ({Par(), Par(), Par(), Par(), Red(), Red(), Red()})
+  ///   Layout: {/*input=*/{n,
+  ///                       strideH * oh + dilationH * fh,
+  ///                       strideW * ow + dilationW * fw,
+  ///                       ic},
+  ///            /*filter=*/{fh, fw, ic, oc},
+  ///            /*output=*/{n, oh, ow, oc}}
+  /// ```
+  /// Window dimensions are all unrolled right now.
+  /// TODO: do not unroll window dimensions.
+  FailureOr<Operation *> conv() {
+    if (!valid)
+      return failure();
+
+    int nSize = resShapedType.getDimSize(0);
+    int ohSize = resShapedType.getDimSize(1);
+    int owSize = resShapedType.getDimSize(2);
+    int ocSize = resShapedType.getDimSize(3);
+    int fhSize = rhsShapedType.getDimSize(0);
+    int fwSize = rhsShapedType.getDimSize(1);
+    int icSize = rhsShapedType.getDimSize(2);
+
+    vector::TransferWriteOp write;
+    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+
+    // With unit strides the whole output window can be read/written as one
+    // contiguous slice; otherwise step the window dimensions one at a time.
+    int64_t ohSizeStep = strideH == 1 ? ohSize : 1;
+    int64_t owSizeStep = strideW == 1 ? owSize : 1;
+
+    // Vector type for lhs slices per contraction.
+    VectorType lhsType =
+        VectorType::get({nSize, ohSizeStep, owSizeStep, icSize},
+                        lhsShapedType.getElementType());
+    // Vector type for result slices per contraction.
+    VectorType resType =
+        VectorType::get({nSize, ohSizeStep, owSizeStep, ocSize},
+                        resShapedType.getElementType());
+
+    // Within each unrolled slice, fh/fw have extent 1 and thus act as
+    // parallel dimensions of the contraction; only ic reduces.
+    StringRef par = Par().strRef, red = Red().strRef;
+    SmallVector<StringRef> contractIterators{par, par, par, par, par, par, red};
+
+    AffineExpr nDim, ohDim, owDim, ocDim, fhDim, fwDim, icDim;
+    bindDims(ctx, nDim, ohDim, owDim, ocDim, fhDim, fwDim, icDim);
+
+    // Unroll along fh/fw and read slices of lhs and rhs.
+    for (int64_t fh = 0; fh < fhSize; ++fh) {
+      for (int64_t fw = 0; fw < fwSize; ++fw) {
+        // Read rhs slice of size {1, 1, ic, oc} @ [fh, fw, 0, 0].
+        Value fhVal = builder.create<arith::ConstantIndexOp>(loc, fh);
+        Value fwVal = builder.create<arith::ConstantIndexOp>(loc, fw);
+        VectorType rhsType = VectorType::get({1, 1, icSize, ocSize},
+                                             rhsShapedType.getElementType());
+        Value rhs = builder.create<vector::TransferReadOp>(
+            loc, rhsType, rhsShaped, ValueRange{fhVal, fwVal, zero, zero});
+
+        for (int64_t oh = 0; oh < ohSize; oh += ohSizeStep) {
+          for (int64_t ow = 0; ow < owSize; ow += owSizeStep) {
+            // Read lhs slice of size {n, ohSizeStep, owSizeStep, ic}
+            // @ [0, SH * oh + DH * fh, SW * ow + DW * fw, 0].
+            Value ihVal = builder.create<arith::ConstantIndexOp>(
+                loc, strideH * oh + dilationH * fh);
+            Value iwVal = builder.create<arith::ConstantIndexOp>(
+                loc, strideW * ow + dilationW * fw);
+            Value lhs = builder.create<vector::TransferReadOp>(
+                loc, lhsType, lhsShaped, ValueRange{zero, ihVal, iwVal, zero});
+
+            // Read res slice: {n, ohSizeStep, owSizeStep, oc} @ [0, oh, ow, 0].
+            Value ohVal = builder.create<arith::ConstantIndexOp>(loc, oh);
+            Value owVal = builder.create<arith::ConstantIndexOp>(loc, ow);
+            // When operating on tensors, reading from the updated value is
+            // required for vector.transfer_read/write hoisting to function as
+            // expected.
+            Value res = builder.create<vector::TransferReadOp>(
+                loc, resType, resShaped, ValueRange{zero, ohVal, owVal, zero});
+
+            // Compute contraction:
+            //   I{n, ohSizeStep, owSizeStep, ic} * F{1, 1, ic, oc}
+            //     -> O{n, ohSizeStep, owSizeStep, oc}
+            res = builder.create<vector::ContractionOp>(
+                loc, lhs, rhs, res,
+                MapList{{nDim, ohDim, owDim, icDim},
+                        {fhDim, fwDim, icDim, ocDim},
+                        {nDim, ohDim, owDim, ocDim}},
+                contractIterators);
+
+            // Write back res slice: {n, ohSizeStep, owSizeStep, oc}
+            // @ [0, oh, ow, 0].
+            write = builder.create<vector::TransferWriteOp>(
+                loc, res, resShaped, ValueRange{zero, ohVal, owVal, zero});
+            if (write.getNumResults() == 1)
+              resShaped = write->getResult(0);
+          }
+        }
+      }
+    }
+
+    return write.getOperation();
+  }
+
+  /// Entry point that transposes into the common form:
+  ///   {/*input=*/{n,
+  ///               strideH * oh + dilationH * fh,
+  ///               strideW * ow + dilationW * fw,
+  ///               ic},
+  ///    /*filter=*/{fh, fw, ic, oc},
+  ///    /*output=*/{n, oh, ow, oc}}
+  FailureOr<Operation *> generateConv() {
+    if (!valid) {
+      LLVM_DEBUG(llvm::dbgs() << "invalid for 2-D generator\n");
+      return failure();
+    }
+
+    AffineExpr n, oh, ow, oc, fh, fw, ic;
+    bindDims(ctx, n, oh, ow, oc, fh, fw, ic);
+
+    if (!iters({Par(), Par(), Par(), Par(), Red(), Red(), Red()})) {
+      LLVM_DEBUG(llvm::dbgs() << "failed to match iterators\n");
+      return failure();
+    }
+
+    // No transposition needed.
+    if (layout({/*lhsIndex=*/{n, strideH * oh + dilationH * fh,
+                              strideW * ow + dilationW * fw, ic},
+                /*rhsIndex=*/{fh, fw, ic, oc},
+                /*resIndex=*/{n, oh, ow, oc}}))
+      return conv();
+
+    // Keep all failure diagnostics inside LLVM_DEBUG so nothing prints in
+    // release builds.
+    LLVM_DEBUG({
+      llvm::dbgs() << "failed to match layout\n";
+      llvm::dbgs() << "linalg op indexing maps: [";
+      llvm::interleaveComma(maps, llvm::dbgs());
+      llvm::dbgs() << "]\n";
+      auto inferredMaps = AffineMap::inferFromExprList(
+          MapList{/*lhsIndex=*/{n, strideH * oh + dilationH * fh,
+                                strideW * ow + dilationW * fw, ic},
+                  /*rhsIndex=*/{fh, fw, ic, oc},
+                  /*resIndex=*/{n, oh, ow, oc}});
+      llvm::dbgs() << "inferred indexing maps: [";
+      llvm::interleaveComma(inferredMaps, llvm::dbgs());
+      llvm::dbgs() << "]\n";
+    });
+    return failure();
+  }
+
+private:
+  bool valid;
+  int strideH, strideW, dilationH, dilationW;
+  Value lhsShaped, rhsShaped, resShaped;
+  ShapedType lhsShapedType, rhsShapedType, resShapedType;
+};
 } // namespace
 
 /// Helper function to vectorize a `linalgOp` with convolution semantics.
@@ -1606,11 +1788,45 @@
   // TODO: these are legitimately part of ConvolutionOpInterface.
   auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
   auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
-  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
-  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+
   LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
-  Conv1D_NWC_WCF_Generator e(b, linalgOp, stride, dilation);
-  return e.generateConv();
+
+  { // 1-D case
+    auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
+    auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+
+    Conv1DGenerator e(b, linalgOp, stride, dilation);
+    auto conv = e.generateConv();
+    if (succeeded(conv))
+      return conv;
+  }
+
+  // Strides/dilations for a 2-D convolution are rank-1 attributes with
+  // exactly two elements.
+  auto is2DAttr = [](DenseIntElementsAttr attr) {
+    auto type = attr.getType().dyn_cast<RankedTensorType>();
+    return type && type.getRank() == 1 && type.getDimSize(0) == 2;
+  };
+
+  // 2-D case
+  if ((!strides || is2DAttr(strides)) && (!dilations || is2DAttr(dilations))) {
+    uint64_t strideH = 1, strideW = 1;
+    if (strides && is2DAttr(strides)) {
+      strideH = strides.getValue<uint64_t>(0);
+      strideW = strides.getValue<uint64_t>(1);
+    }
+
+    uint64_t dilationH = 1, dilationW = 1;
+    if (dilations && is2DAttr(dilations)) {
+      dilationH = dilations.getValue<uint64_t>(0);
+      dilationW = dilations.getValue<uint64_t>(1);
+    }
+
+    Conv2DGenerator e(b, linalgOp, strideH, strideW, dilationH, dilationW);
+    auto conv = e.generateConv();
+    if (succeeded(conv))
+      return conv;
+  }
+
+  return failure();
 }
 
 struct VectorizeConvolution
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -109,8 +109,6 @@
 
 // -----
 
-
-
 // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 // CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -150,3 +148,142 @@
   outs(%output : memref<4x2x8xf32>)
   return
 }
+
+// -----
+
+// 2-D, memref, 1x1 filter, non-1 strides
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+
+// CHECK: func @conv2d_nhwc_4x2x2x8_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x6x3xf32>, %[[FILTER:.+]]: memref<1x1x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x2x8xf32>)
+func @conv2d_nhwc_4x2x2x8_memref(%input: memref<4x3x6x3xf32>, %filter: memref<1x1x3x8xf32>, %output: memref<4x2x2x8xf32>) {
+  linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
+    ins(%input, %filter : memref<4x3x6x3xf32>, memref<1x1x3x8xf32>)
+    outs(%output : memref<4x2x2x8xf32>)
+  return
+}
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+// CHECK: %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : memref<1x1x3x8xf32>, vector<1x1x3x8xf32>
+
+/// oh == 0, ow == 0
+// CHECK: %[[V_INPUT_00:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : memref<4x3x6x3xf32>, vector<4x1x1x3xf32>
+// CHECK: %[[V_OUTPUT_00:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : memref<4x2x2x8xf32>, vector<4x1x1x8xf32>
+// CHECK: %[[CONTRACT0:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: } %[[V_INPUT_00]], %[[V_FILTER]], %[[V_OUTPUT_00]] : vector<4x1x1x3xf32>, vector<1x1x3x8xf32> into vector<4x1x1x8xf32>
+// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x1x1x8xf32>, memref<4x2x2x8xf32>
+
+/// oh == 0, ow == 1
+// CHECK: %[[V_INPUT_03:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C3]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_01:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+// CHECK: %[[CONTRACT1:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: } %[[V_INPUT_03]], %[[V_FILTER]], %[[V_OUTPUT_01]]
+// CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C1]], %[[C0]]]
+
+/// oh == 1, ow == 0
+// CHECK: %[[V_INPUT_20:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_10:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[CONTRACT2:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: } %[[V_INPUT_20]], %[[V_FILTER]], %[[V_OUTPUT_10]]
+// CHECK: vector.transfer_write %[[CONTRACT2]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]], %[[C0]]]
+
+/// oh == 1, ow == 1
+// CHECK: %[[V_INPUT_23:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C3]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_OUTPUT_11:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C1]], %[[C0]]], %[[F0]]
+// CHECK: %[[CONTRACT3:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: } %[[V_INPUT_23]], %[[V_FILTER]], %[[V_OUTPUT_11]]
+// CHECK: vector.transfer_write %[[CONTRACT3]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C1]], %[[C0]]]
+
+// -----
+
+// 2-D, tensor, non-1 filter, non-1 dilations
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+
+// CHECK: func @conv2d_nhwc_4x1x2x8_tensor
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<4x3x5x3xf32>, %[[FILTER:.+]]: tensor<2x2x3x8xf32>, %[[INIT:.+]]: tensor<4x1x2x8xf32>)
+func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x3x5x3xf32>, %filter: tensor<2x2x3x8xf32>, %init: tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %filter : tensor<4x3x5x3xf32>, tensor<2x2x3x8xf32>)
+    outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
+  return %0 : tensor<4x1x2x8xf32>
+}
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+// fh == 0, fw == 0
+// CHECK: %[[V_FILTER_00:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : tensor<2x2x3x8xf32>, vector<1x1x3x8xf32>
+// CHECK: %[[V_INPUT_00:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : tensor<4x3x5x3xf32>, vector<4x1x2x3xf32>
+// CHECK: %[[V_OUTPUT_00:.+]] = vector.transfer_read %[[INIT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true, true, true, true]} : tensor<4x1x2x8xf32>, vector<4x1x2x8xf32>
+// CHECK: %[[CONTRACT0:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: } %[[V_INPUT_00]], %[[V_FILTER_00]], %[[V_OUTPUT_00]] : vector<4x1x2x3xf32>, vector<1x1x3x8xf32> into vector<4x1x2x8xf32>
+// CHECK: %[[WRITE0:.+]] = vector.transfer_write %[[CONTRACT0]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x1x2x8xf32>, tensor<4x1x2x8xf32>
+
+// fh == 0, fw == 1
+// CHECK: %[[V_FILTER_01:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C1]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_INPUT_03:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C3]], %[[C0]]], %[[F0]]
+// CHECK: %[[CONTRACT1:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: } %[[V_INPUT_03]], %[[V_FILTER_01]], %[[CONTRACT0]]
+// CHECK: %[[WRITE1:.+]] = vector.transfer_write %[[CONTRACT1]], %[[WRITE0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+
+// fh == 1, fw == 0
+// CHECK: %[[V_FILTER_10:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_INPUT_20:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[CONTRACT2:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: } %[[V_INPUT_20]], %[[V_FILTER_10]], %[[CONTRACT1]]
+// CHECK: %[[WRITE2:.+]] = vector.transfer_write %[[CONTRACT2]], %[[WRITE1]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+
+// fh == 1, fw == 1
+// CHECK: %[[V_FILTER_11:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C1]], %[[C0]], %[[C0]]], %[[F0]]
+// CHECK: %[[V_INPUT_23:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C3]], %[[C0]]], %[[F0]]
+// CHECK: %[[CONTRACT3:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: } %[[V_INPUT_23]], %[[V_FILTER_11]], %[[CONTRACT2]]
+// CHECK: %[[WRITE3:.+]] = vector.transfer_write %[[CONTRACT3]], %[[WRITE2]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+
+// CHECK: return %[[WRITE3]]
+
+// -----
+
+// TODO: add support for linalg.conv_2d_nchw_fchw
+
+// CHECK-LABEL: func @conv2d_nchw_4x8x2x2_tensor
+func @conv2d_nchw_4x8x2x2_tensor(%input: tensor<4x3x3x6xf32>, %filter: tensor<8x3x1x1xf32>, %init: tensor<4x8x2x2xf32>) -> tensor<4x8x2x2xf32> {
+  // CHECK: linalg.conv_2d_nchw_fchw
+  %0 = linalg.conv_2d_nchw_fchw
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
+    ins(%input, %filter : tensor<4x3x3x6xf32>, tensor<8x3x1x1xf32>)
+    outs(%init : tensor<4x8x2x2xf32>) -> tensor<4x8x2x2xf32>
+  return %0 : tensor<4x8x2x2xf32>
+}