Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -385,8 +385,10 @@ [&](auto op) { return CombiningKind::ADD; }) .Case([&](auto op) { return CombiningKind::AND; }) .Case([&](auto op) { return CombiningKind::MAXSI; }) + .Case([&](auto op) { return CombiningKind::MAXUI; }) .Case([&](auto op) { return CombiningKind::MAXF; }) .Case([&](auto op) { return CombiningKind::MINSI; }) + .Case([&](auto op) { return CombiningKind::MINUI; }) .Case([&](auto op) { return CombiningKind::MINF; }) .Case( [&](auto op) { return CombiningKind::MUL; }) @@ -1796,6 +1798,34 @@ } namespace { +bool isCastOfBlockArgument(Operation *op) { + return isa(op) && op->getNumOperands() == 1 && + op->getOperand(0).isa(); +} + +bool isBlockArgumentOrCastOfBlockArgument(Value operand) { + if (operand.isa()) + return true; + if (Operation *op = operand.getDefiningOp()) + return isCastOfBlockArgument(op); + return false; +} + +bool isPoolingOp(vector::CombiningKind kind) { + switch (kind) { + case vector::CombiningKind::ADD: + case vector::CombiningKind::MAXF: + case vector::CombiningKind::MAXSI: + case vector::CombiningKind::MAXUI: + case vector::CombiningKind::MINF: + case vector::CombiningKind::MINSI: + case vector::CombiningKind::MINUI: + return true; + default: + return false; + } +} + /// Generate a vector implementation for either: /// ``` /// Op def: ( n, w, c, kw, f ) @@ -1838,41 +1868,66 @@ resShapedType = resShaped.getType().dyn_cast(); if (!lhsShapedType || !rhsShapedType || !resShapedType) return; - if (lhsShapedType.getRank() != 3 || - (rhsShapedType.getRank() != 2 && rhsShapedType.getRank() != 3) || - resShapedType.getRank() != 3) + if (!(lhsShapedType.getRank() == 3 && resShapedType.getRank() == 3)) return; - // Check for reduction `add` preceded by `mul`. Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0)); if (!reduceOp) return; - std::optional maybeKind; - maybeKind = getCombinerOpKind(reduceOp); - if (!maybeKind || *maybeKind != vector::CombiningKind::ADD) - return; - // Check for single `mul` predecessor. The `mul` operands must be block - // arguments or extension of block arguments. - Operation *mulOp = nullptr; - for (Value operand : reduceOp->getOperands()) { - if (operand.isa()) - continue; - if (mulOp) - return; - mulOp = operand.getDefiningOp(); - if (!mulOp || !isa(mulOp)) + poolRedOp = reduceOp->getName().getIdentifier(); + + // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction + // + yield) and rhs is not used) then it is the body of a pooling + // If conv, check for single `mul` predecessor. The `mul` operands must be + // block arguments or extension of block arguments. + // Otherwise, check for one or zero `ext` predecessor. The `ext` operands + // must be block arguments or extension of block arguments. + auto countBlockArguments = + llvm::count_if(reduceOp->getOperands(), + [](Value v) { return v.isa(); }); + switch (countBlockArguments) { + case 1: { + // Will be convolution if feeder is a MulOp. + // Otherwise, if it can be pooling. + auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) { + return !v.isa(); + }); + Operation *feedOp = (*feedValIt).getDefiningOp(); + if (isCastOfBlockArgument(feedOp)) { + oper = Pool; + isPoolExt = true; + poolExtOp = feedOp->getName().getIdentifier(); + } else if (!(isa(feedOp) && + llvm::all_of(feedOp->getOperands(), [](Value v) { + return isBlockArgumentOrCastOfBlockArgument(v); + }))) { return; + } + break; } - if (!mulOp) + case 2: + // Must be pooling + oper = Pool; + isPoolExt = false; + break; + default: return; - for (Value operand : mulOp->getOperands()) { - if (Operation *def = operand.getDefiningOp()) { - if (!isa(def)) - return; - operand = def->getOperand(0); - } - if (!operand.isa()) + } + std::optional maybeKind = + getCombinerOpKind(reduceOp); + if (!(maybeKind && (*maybeKind == vector::CombiningKind::ADD || + (oper == Pool && isPoolingOp(*maybeKind))))) { + return; + } + switch (oper) { + case Conv: + if (!(rhsShapedType.getRank() == 2 || rhsShapedType.getRank() == 3)) + return; + break; + case Pool: + if (rhsShapedType.getRank() != 1) return; + break; } // The op is now known to be valid. valid = true; @@ -1888,17 +1943,27 @@ /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is /// > 1. FailureOr conv(Conv1DOpOrder conv1DOpOrder) { - if (!valid) - return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv"); + if (!valid) { + return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool"); + } int64_t nSize, wSize, cSize, kwSize, fSize; SmallVector lhsShape, rhsShape, resShape; switch (conv1DOpOrder) { case Conv1DOpOrder::Nwc: - // kernel{kw, c, f} - bindShapeDims(rhsShapedType, kwSize, cSize, fSize); // out{n, w, f} - bindShapeDims(resShapedType, nSize, wSize); + bindShapeDims(resShapedType, nSize, wSize, fSize); + switch (oper) { + case Conv: + // kernel{kw, c, f} + bindShapeDims(rhsShapedType, kwSize, cSize); + break; + case Pool: + // kernel{kw} + bindShapeDims(rhsShapedType, kwSize); + cSize = fSize; + break; + } lhsShape = {nSize, // iw = ow * sw + kw * dw - 1 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) @@ -1906,21 +1971,44 @@ ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1, cSize}; - rhsShape = {kwSize, cSize, fSize}; + switch (oper) { + case Conv: + rhsShape = {kwSize, cSize, fSize}; + break; + case Pool: + rhsShape = {kwSize}; + break; + } resShape = {nSize, wSize, fSize}; break; case Conv1DOpOrder::Ncw: - // kernel{f, c, kw} - bindShapeDims(rhsShapedType, fSize, cSize, kwSize); // out{n, f, w} bindShapeDims(resShapedType, nSize, fSize, wSize); + switch (oper) { + case Conv: + // kernel{f, c, kw} + bindShapeDims(rhsShapedType, fSize, cSize, kwSize); + break; + case Pool: + // kernel{kw} + bindShapeDims(rhsShapedType, kwSize); + cSize = fSize; + break; + } lhsShape = {nSize, cSize, // iw = ow * sw + kw * dw - 1 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) // Perform the proper inclusive -> exclusive -> inclusive. ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1}; - rhsShape = {fSize, cSize, kwSize}; + switch (oper) { + case Conv: + rhsShape = {fSize, cSize, kwSize}; + break; + case Pool: + rhsShape = {kwSize}; + break; + } resShape = {nSize, fSize, wSize}; break; } @@ -1944,8 +2032,11 @@ Value lhs = rewriter.create( loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}); // Read rhs slice of size {kw, c, f} @ [0, 0, 0]. - Value rhs = rewriter.create( - loc, rhsType, rhsShaped, ValueRange{zero, zero, zero}); + Value rhs = nullptr; + // Do not do for pooling + if (oper == Conv) + rhs = rewriter.create( + loc, rhsType, rhsShaped, ValueRange{zero, zero, zero}); // Read res slice of size {n, w, f} @ [0, 0, 0]. Value res = rewriter.create( loc, resType, resShaped, ValueRange{zero, zero, zero}); @@ -1964,7 +2055,10 @@ lhs = rewriter.create(loc, lhs, permLhs); // fcw -> wcf static constexpr std::array permRhs = {2, 1, 0}; - rhs = rewriter.create(loc, rhs, permRhs); + + // Do not do for pooling + if (oper == Conv) + rhs = rewriter.create(loc, rhs, permRhs); // nfw -> nwf static constexpr std::array permRes = {0, 2, 1}; res = rewriter.create(loc, res, permRes); @@ -1988,10 +2082,12 @@ } } // Extract rhs slice of size {c, f} @ [kw]. - for (int64_t kw = 0; kw < kwSize; ++kw) { - rhsVals.push_back(rewriter.create( - loc, rhs, /*offsets=*/ArrayRef{kw})); - } + // Do not do for pooling + if (oper == Conv) + for (int64_t kw = 0; kw < kwSize; ++kw) { + rhsVals.push_back(rewriter.create( + loc, rhs, /*offsets=*/ArrayRef{kw})); + } // Extract res slice: {n, wSizeStep, f} @ [0, w, 0]. for (int64_t w = 0; w < wSize; w += wSizeStep) { resVals.push_back(rewriter.create( @@ -2005,11 +2101,21 @@ return kw * (wSize / wSizeStep) + w; }; - // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} + // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or + // Perform simple arith operation for pooling for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - resVals[w] = conv1dSliceAsContraction( - rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); + switch (oper) { + case Conv: + resVals[w] = conv1dSliceAsContraction(rewriter, loc, + lhsVals[linearIndex(kw, w)], + rhsVals[kw], resVals[w]); + break; + case Pool: + resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)], + resVals[w]); + break; + } } } @@ -2060,6 +2166,17 @@ /*iteratorTypes=*/ArrayRef{par, par, par, red}); } + // Create a reduction: lhs{n, w, c} -> res{n, w, c} + Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs, + Value res) { + if (isPoolExt) { + lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0); + } + return rewriter + .create(loc, poolRedOp, ArrayRef{lhs, res}, res.getType()) + ->getResult(0); + } + /// Generate a vector implementation for: /// ``` /// Op def: ( n, w, c, kw) @@ -2236,6 +2353,7 @@ /*rhsIndex*/ {kw, c, f}, /*resIndex*/ {n, w, f}})) return conv(Conv1DOpOrder::Nwc); + return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout"); } @@ -2256,6 +2374,41 @@ return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout"); } + /// Entry point that transposes into the common form: + /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling + FailureOr generateNwcPooling() { + AffineExpr n, w, c, kw; + bindDims(ctx, n, w, c, kw); + if (!iters({Par(), Par(), Par(), Red()})) + return rewriter.notifyMatchFailure(op, + "failed to match pooling 3-par 1-red"); + + // No transposition needed. + if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, + /*rhsIndex*/ {kw}, + /*resIndex*/ {n, w, c}})) + return conv(Conv1DOpOrder::Nwc); + + return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout"); + } + + /// Entry point that transposes into the common form: + /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling + FailureOr generateNcwPooling() { + AffineExpr n, w, c, kw; + bindDims(ctx, n, c, w, kw); + if (!iters({Par(), Par(), Par(), Red()})) + return rewriter.notifyMatchFailure(op, + "failed to match pooling 3-par 1-red"); + + if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw}, + /*rhsIndex*/ {kw}, + /*resIndex*/ {n, c, w}})) + return conv(Conv1DOpOrder::Ncw); + + return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout"); + } + /// Entry point that transposes into the common form: /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} FailureOr generateDilatedConv() { @@ -2275,7 +2428,12 @@ } private: + enum Oper { Conv, Pool }; bool valid = false; + Oper oper = Conv; + StringAttr poolRedOp; + StringAttr poolExtOp; + bool isPoolExt = false; int strideW, dilationW; Value lhsShaped, rhsShaped, resShaped; ShapedType lhsShapedType, rhsShapedType, resShapedType; @@ -2299,6 +2457,12 @@ if (succeeded(res)) return res; res = e.generateNcwConv(); + if (succeeded(res)) + return res; + res = e.generateNwcPooling(); + if (succeeded(res)) + return res; + res = e.generateNcwPooling(); if (succeeded(res)) return res; return e.generateDilatedConv(); Index: mlir/test/Dialect/Linalg/vectorize-convolution.mlir =================================================================== --- mlir/test/Dialect/Linalg/vectorize-convolution.mlir +++ mlir/test/Dialect/Linalg/vectorize-convolution.mlir @@ -571,3 +571,282 @@ // CHECK: %[[CONT:.*]] = vector.contract // {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32> // CHECK: vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + + + +func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) { + linalg.pooling_nwc_sum + {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x4x3xf32>, memref<1xf32>) + outs(%output : memref<4x2x3xf32>) + return +} + +// CHECK: func.func @pooling_nwc_sum_memref_1_2_1_3 +// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xf32>, %[[Varg1:.+]]: memref<1xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>) +// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x4x3xf32>, vector<4x4x3xf32> +// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32> +// CHECK-DAG: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V6:.+]] = arith.addf %[[V2]], %[[V4]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V7:.+]] = arith.addf %[[V3]], %[[V5]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V8:.+]] = vector.insert_strided_slice %[[V6]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> +// CHECK-DAG: %[[V9:.+]] = vector.insert_strided_slice %[[V7]], %[[V8]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> +// CHECK-DAG: vector.transfer_write %[[V9]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32> + +// ----- + +func.func @pooling_nwc_max_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) { + linalg.pooling_nwc_max + {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x4x3xf32>, memref<1xf32>) + outs(%output : memref<4x2x3xf32>) + return +} + +// CHECK: func.func @pooling_nwc_max_memref_1_2_1_3 +// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xf32>, %[[Varg1:.+]]: memref<1xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>) +// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x4x3xf32>, vector<4x4x3xf32> +// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32> +// CHECK-DAG: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V6:.+]] = arith.maxf %[[V2]], %[[V4]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V7:.+]] = arith.maxf %[[V3]], %[[V5]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V8:.+]] = vector.insert_strided_slice %[[V6]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> +// CHECK-DAG: %[[V9:.+]] = vector.insert_strided_slice %[[V7]], %[[V8]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> +// CHECK-DAG: vector.transfer_write %[[V9]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32> + +// ----- + +// The i8i8i32 case is similar to f32 case, so checking one case is enough for +// test coverage. +func.func @pooling_nwc_sum_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %filter: memref<1xi8>, %output: memref<4x2x3xi32>) { + linalg.pooling_nwc_sum + {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x4x3xi8>, memref<1xi8>) + outs(%output : memref<4x2x3xi32>) + return +} + +// CHECK: func.func @pooling_nwc_sum_i8i8i32_memref_1_2_1_3 +// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xi8>, %[[Varg1:.+]]: memref<1xi8>, %[[Varg2:.+]]: memref<4x2x3xi32>) +// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[Vc0_i8:.+]] = arith.constant 0 : i8 +// CHECK-DAG: %[[Vc0_i32:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i8]] {in_bounds = [true, true, true]} : memref<4x4x3xi8>, vector<4x4x3xi8> +// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i32]] {in_bounds = [true, true, true]} : memref<4x2x3xi32>, vector<4x2x3xi32> +// CHECK-DAG: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8> +// CHECK-DAG: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8> +// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32> +// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32> +// CHECK-DAG: %[[V6:.+]] = arith.extsi %[[V2]] : vector<4x1x3xi8> to vector<4x1x3xi32> +// CHECK-DAG: %[[V7:.+]] = arith.addi %[[V6]], %[[V4]] : vector<4x1x3xi32> +// CHECK-DAG: %[[V8:.+]] = arith.extsi %[[V3]] : vector<4x1x3xi8> to vector<4x1x3xi32> +// CHECK-DAG: %[[V9:.+]] = arith.addi %[[V8]], %[[V5]] : vector<4x1x3xi32> +// CHECK-DAG: %[[V10:.+]] = vector.insert_strided_slice %[[V7]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32> +// CHECK-DAG: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32> +// CHECK-DAG: vector.transfer_write %[[V11]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xi32>, memref<4x2x3xi32> +// CHECK-DAG: return + +// ----- + +// The i8i8i32 case is similar to f32 case, so checking one case is enough for +// test coverage. +func.func @pooling_nwc_max_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %filter: memref<1xi8>, %output: memref<4x2x3xi32>) { + linalg.pooling_nwc_max + {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x4x3xi8>, memref<1xi8>) + outs(%output : memref<4x2x3xi32>) + return +} + +// CHECK: func.func @pooling_nwc_max_i8i8i32_memref_1_2_1_3 +// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xi8>, %[[Varg1:.+]]: memref<1xi8>, %[[Varg2:.+]]: memref<4x2x3xi32>) +// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[Vc0_i8:.+]] = arith.constant 0 : i8 +// CHECK-DAG: %[[Vc0_i32:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i8]] {in_bounds = [true, true, true]} : memref<4x4x3xi8>, vector<4x4x3xi8> +// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i32]] {in_bounds = [true, true, true]} : memref<4x2x3xi32>, vector<4x2x3xi32> +// CHECK-DAG: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8> +// CHECK-DAG: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8> +// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32> +// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32> +// CHECK-DAG: %[[V6:.+]] = arith.extsi %[[V2]] : vector<4x1x3xi8> to vector<4x1x3xi32> +// CHECK-DAG: %[[V7:.+]] = arith.maxsi %[[V6]], %[[V4]] : vector<4x1x3xi32> +// CHECK-DAG: %[[V8:.+]] = arith.extsi %[[V3]] : vector<4x1x3xi8> to vector<4x1x3xi32> +// CHECK-DAG: %[[V9:.+]] = arith.maxsi %[[V8]], %[[V5]] : vector<4x1x3xi32> +// CHECK-DAG: %[[V10:.+]] = vector.insert_strided_slice %[[V7]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32> +// CHECK-DAG: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32> +// CHECK-DAG: vector.transfer_write %[[V11]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xi32>, memref<4x2x3xi32> +// CHECK-DAG: return + +// ----- + +func.func @pooling_nwc_sum_memref_2_2_2_3(%input: memref<4x6x3xf32>, %filter: memref<2xf32>, %output: memref<4x2x3xf32>) { + linalg.pooling_nwc_sum + {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x6x3xf32>, memref<2xf32>) + outs(%output : memref<4x2x3xf32>) + return +} + +// CHECK: func.func @pooling_nwc_sum_memref_2_2_2_3 +// CHECK-SAME: (%[[Varg0:.+]]: memref<4x6x3xf32>, %[[Varg1:.+]]: memref<2xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>) +// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x6x3xf32>, vector<4x6x3xf32> +// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32> +// CHECK-DAG: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V6:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V7:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V8:.+]] = arith.addf %[[V2]], %[[V6]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V9:.+]] = arith.addf %[[V3]], %[[V7]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V10:.+]] = arith.addf %[[V4]], %[[V8]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V11:.+]] = arith.addf %[[V5]], %[[V9]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V12:.+]] = vector.insert_strided_slice %[[V10]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> +// CHECK-DAG: %[[V13:.+]] = vector.insert_strided_slice %[[V11]], %[[V12]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> +// CHECK-DAG: vector.transfer_write %[[V13:.+]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32> + + +// ----- + +func.func @pooling_ncw_sum_memref_1_2_1_3(%input: memref<4x3x4xf32>, %filter: memref<1xf32>, %output: memref<4x3x2xf32>) { + linalg.pooling_ncw_sum + {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x3x4xf32>, memref<1xf32>) + outs(%output : memref<4x3x2xf32>) + return +} + +// CHECK: func.func @pooling_ncw_sum_memref_1_2_1_3 +// CHECK-SAME: (%[[Varg0:.+]]: memref<4x3x4xf32>, %[[Varg1:.+]]: memref<1xf32>, %[[Varg2:.+]]: memref<4x3x2xf32>) +// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x4xf32>, vector<4x3x4xf32> +// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x2xf32>, vector<4x3x2xf32> +// CHECK-DAG: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x3x4xf32> to vector<4x4x3xf32> +// CHECK-DAG: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32> +// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V6:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V7:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V8:.+]] = arith.addf %[[V4]], %[[V6]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V9:.+]] = arith.addf %[[V5]], %[[V7]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V10:.+]] = vector.insert_strided_slice %[[V8]], %[[V3]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> +// CHECK-DAG: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> +// CHECK-DAG: %[[V12:.+]] = vector.transpose %[[V11]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32> +// CHECK-DAG: vector.transfer_write %[[V12:.+]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x3x2xf32>, memref<4x3x2xf32> + + +// ----- + +func.func @pooling_nwc_sum_mixed_type_memref_1_2_1_1(%input: memref<1x2x3xf16>, %filter: memref<1xf16>, %output: memref<1x2x3xf32>) { + linalg.pooling_nwc_sum + {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} + ins(%input, %filter : memref<1x2x3xf16>, memref<1xf16>) + outs(%output : memref<1x2x3xf32>) + return +} + +// CHECK: func.func @pooling_nwc_sum_mixed_type_memref_1_2_1_1 +// CHECK-SAME: (%[[Varg0:.+]]: memref<1x2x3xf16>, %[[Varg1:.+]]: memref<1xf16>, %[[Varg2:.+]]: memref<1x2x3xf32>) +// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f16 +// CHECK-DAG: %[[Vcst_0:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<1x2x3xf16>, vector<1x2x3xf16> +// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst_0]] {in_bounds = [true, true, true]} : memref<1x2x3xf32>, vector<1x2x3xf32> +// CHECK-DAG: %[[V2:.+]] = arith.extf %[[V0]] : vector<1x2x3xf16> to vector<1x2x3xf32> +// CHECK-DAG: %[[V3:.+]] = arith.addf %[[V2]], %[[V1]] : vector<1x2x3xf32> +// CHECK-DAG: vector.transfer_write %[[V3:.+]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<1x2x3xf32>, memref<1x2x3xf32> + +// ----- + +func.func @pooling_nwc_sum_memref_2_2_2_1(%input: memref<4x4x3xf32>, %filter: memref<2xf32>, %output: memref<4x2x3xf32>) { + linalg.pooling_nwc_sum + {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %filter : memref<4x4x3xf32>, memref<2xf32>) + outs(%output : memref<4x2x3xf32>) + return +} + +// CHECK: func.func @pooling_nwc_sum_memref_2_2_2_1 +// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xf32>, %[[Varg1:.+]]: memref<2xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>) +// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x4x3xf32>, vector<4x4x3xf32> +// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32> +// CHECK-DAG: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32> +// CHECK-DAG: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32> +// CHECK-DAG: %[[V4:.+]] = arith.addf %[[V2]], %[[V1]] : vector<4x2x3xf32> +// CHECK-DAG: %[[V5:.+]] = arith.addf %[[V3]], %[[V4]] : vector<4x2x3xf32> +// CHECK-DAG: vector.transfer_write %[[V5:.+]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32> + + +// ----- + +func.func @pooling_ncw_sum_memref_2_2_2_3(%input: memref<4x3x6xf32>, %filter: memref<2xf32>, %output: memref<4x3x2xf32>) { + linalg.pooling_ncw_sum + {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} + ins(%input, %filter : memref<4x3x6xf32>, memref<2xf32>) + outs(%output : memref<4x3x2xf32>) + return +} + +// CHECK: func.func @pooling_ncw_sum_memref_2_2_2_3 +// CHECK-SAME: (%[[Varg0:.+]]: memref<4x3x6xf32>, %[[Varg1:.+]]: memref<2xf32>, %[[Varg2:.+]]: memref<4x3x2xf32>) +// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x6xf32>, vector<4x3x6xf32> +// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x2xf32>, vector<4x3x2xf32> +// CHECK-DAG: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x3x6xf32> to vector<4x6x3xf32> +// CHECK-DAG: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32> +// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V6:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V7:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V8:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V9:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32> +// CHECK-DAG: %[[V10:.+]] = arith.addf %[[V4]], %[[V8]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V11:.+]] = arith.addf %[[V5]], %[[V9]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V12:.+]] = arith.addf %[[V6]], %[[V10]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V13:.+]] = arith.addf %[[V7]], %[[V11]] : vector<4x1x3xf32> +// CHECK-DAG: %[[V14:.+]] = vector.insert_strided_slice %[[V12]], %[[V3]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> +// CHECK-DAG: %[[V15:.+]] = vector.insert_strided_slice %[[V13]], %[[V14]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32> +// CHECK-DAG: %[[V16:.+]] = vector.transpose %[[V15]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32> +// CHECK-DAG: vector.transfer_write %[[V16:.+]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x3x2xf32>, memref<4x3x2xf32> + +// ----- + +func.func @pooling_ncw_sum_memref_2_3_2_1(%input: memref<4x2x5xf32>, %filter: memref<2xf32>, %output: memref<4x2x3xf32>) { + linalg.pooling_ncw_sum + {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %filter : memref<4x2x5xf32>, memref<2xf32>) + outs(%output : memref<4x2x3xf32>) + return +} + +// CHECK: func.func @pooling_ncw_sum_memref_2_3_2_1 +// CHECK-SAME: (%[[Varg0:.+]]: memref<4x2x5xf32>, %[[Varg1:.+]]: memref<2xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>) +// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x5xf32>, vector<4x2x5xf32> +// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32> +// CHECK-DAG: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x2x5xf32> to vector<4x5x2xf32> +// CHECK-DAG: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32> +// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 3, 2], strides = [1, 1, 1]} : vector<4x5x2xf32> to vector<4x3x2xf32> +// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 2, 0], sizes = [4, 3, 2], strides = [1, 1, 1]} : vector<4x5x2xf32> to vector<4x3x2xf32> +// CHECK-DAG: %[[V6:.+]] = arith.addf %[[V4]], %[[V3]] : vector<4x3x2xf32> +// CHECK-DAG: %[[V7:.+]] = arith.addf %[[V5]], %[[V6]] : vector<4x3x2xf32> +// CHECK-DAG: %[[V8:.+]] = vector.transpose %[[V7]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32> +// CHECK-DAG: vector.transfer_write %[[V8:.+]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>