Index: mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml =================================================================== --- mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -661,7 +661,7 @@ The partial multiplication results are reduced into a 2D output. Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output." + them to the same data type as the accumulator/output. implements: - LinalgContractionOpInterface structured_op: !LinalgStructuredOpConfig @@ -2279,38 +2279,39 @@ name: I kind: input_tensor type_var: T1 - shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s9, s1 * - s2 + s3 * s4, s5 * s6 + s7 * s8)> + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2 + * s3 + s4 * s5, s6 * s7 + s8 * s9)> - !LinalgOperandDefConfig name: K kind: input_tensor type_var: T2 - shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s9, s3, s7)> + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s1, s4, s8)> - !LinalgOperandDefConfig name: O kind: output_tensor type_var: U - shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s9, s1, s5)> + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2, + s6)> - !LinalgOperandDefConfig name: strides kind: index_attr - index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, - s6)> + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, + s7)> default_indices: - 1 - 1 - !LinalgOperandDefConfig name: dilations kind: index_attr - index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, - s8)> + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s5, + s9)> default_indices: - 1 - 1 indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] - -> (d0, d3, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8)> + -> (d0, d3, d1 * s3 + d4 * s5, d2 * s7 + d5 * s9)> - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (d3, d4, d5)> - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] @@ -3470,6 +3471,497 @@ - !ScalarExpression scalar_arg: I --- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: pooling_nwc_sum + cpp_class_name: PoolingNwcSumOp + doc: |- + Performs sum pooling. + + Layout: + * Input: NWC. + * Kernel: W. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ implements: + - LinalgConvolutionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + kind: input_tensor + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1 * s2 + s3 * s4, s5)> + - !LinalgOperandDefConfig + name: K + kind: input_tensor + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s3)> + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s5)> + - !LinalgOperandDefConfig + name: strides + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2)> + default_indices: + - 1 + - !LinalgOperandDefConfig + name: dilations + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4)> + default_indices: + - 1 + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1 * s2 + d3 * s4, + d2)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d3)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1, d2)> + iterator_types: + - parallel + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: binary + fn_name: add + operands: + - !ScalarExpression + scalar_arg: O + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression + scalar_arg: I +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: pooling_ncw_sum + cpp_class_name: PoolingNcwSumOp + doc: |- + Performs sum pooling. + + Layout: + * Input: NCW. + * Kernel: W. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + kind: input_tensor + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2 * s3 + s4 * s5)> + - !LinalgOperandDefConfig + name: K + kind: input_tensor + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4)> + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2)> + - !LinalgOperandDefConfig + name: strides + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s3)> + default_indices: + - 1 + - !LinalgOperandDefConfig + name: dilations + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s5)> + default_indices: + - 1 + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1, d2 * s3 + d3 + * s5)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d3)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1, d2)> + iterator_types: + - parallel + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: binary + fn_name: add + operands: + - !ScalarExpression + scalar_arg: O + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression + scalar_arg: I +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: pooling_nwc_max + cpp_class_name: PoolingNwcMaxOp + doc: |- + Performs max pooling. 
+ + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + kind: input_tensor + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1 * s2 + s3 * s4, s5)> + - !LinalgOperandDefConfig + name: K + kind: input_tensor + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s3)> + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s5)> + - !LinalgOperandDefConfig + name: strides + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2)> + default_indices: + - 1 + - !LinalgOperandDefConfig + name: dilations + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4)> + default_indices: + - 1 + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1 * s2 + d3 * s4, + d2)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d3)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1, d2)> + iterator_types: + - parallel + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: binary + fn_name: max_signed + operands: + - !ScalarExpression + scalar_arg: O + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression + scalar_arg: I +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: pooling_nwc_max_unsigned + cpp_class_name: PoolingNwcMaxUnsignedOp + doc: |- + Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ implements: + - LinalgConvolutionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + kind: input_tensor + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1 * s2 + s3 * s4, s5)> + - !LinalgOperandDefConfig + name: K + kind: input_tensor + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s3)> + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s5)> + - !LinalgOperandDefConfig + name: strides + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2)> + default_indices: + - 1 + - !LinalgOperandDefConfig + name: dilations + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4)> + default_indices: + - 1 + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1 * s2 + d3 * s4, + d2)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d3)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1, d2)> + iterator_types: + - parallel + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: binary + fn_name: max_unsigned + operands: + - !ScalarExpression + scalar_arg: O + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_unsigned + type_var: U + operands: + - !ScalarExpression + scalar_arg: I +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: pooling_ncw_max + cpp_class_name: PoolingNcwMaxOp + doc: |- + Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + implements: + - LinalgConvolutionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + kind: input_tensor + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2 * s3 + s4 * s5)> + - !LinalgOperandDefConfig + name: K + kind: input_tensor + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4)> + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2)> + - !LinalgOperandDefConfig + name: strides + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s3)> + default_indices: + - 1 + - !LinalgOperandDefConfig + name: dilations + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s5)> + default_indices: + - 1 + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1, d2 * s3 + d3 + * s5)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d3)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1, d2)> + iterator_types: + - parallel + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: binary + fn_name: max_signed + operands: + - !ScalarExpression + scalar_arg: O + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression + scalar_arg: I +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: pooling_nwc_min + cpp_class_name: PoolingNwcMinOp + doc: |- + Performs min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ implements: + - LinalgConvolutionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + kind: input_tensor + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1 * s2 + s3 * s4, s5)> + - !LinalgOperandDefConfig + name: K + kind: input_tensor + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s3)> + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s5)> + - !LinalgOperandDefConfig + name: strides + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2)> + default_indices: + - 1 + - !LinalgOperandDefConfig + name: dilations + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4)> + default_indices: + - 1 + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1 * s2 + d3 * s4, + d2)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d3)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1, d2)> + iterator_types: + - parallel + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: binary + fn_name: min_signed + operands: + - !ScalarExpression + scalar_arg: O + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression + scalar_arg: I +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: pooling_nwc_min_unsigned + cpp_class_name: PoolingNwcMinUnsignedOp + doc: |- + Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1 * s2 + s3 * s4, s5)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s3)>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s5)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2)>
+    default_indices:
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4)>
+    default_indices:
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1 * s2 + d3 * s4,
+      d2)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d3)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1, d2)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: min_unsigned
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: type
+            fn_name: cast_unsigned
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_ndhwc_sum
   cpp_class_name: PoolingNdhwcSumOp
Index: mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -96,6 +96,57 @@
     results.push_back(*depthwise);
     return DiagnosedSilenceableFailure::success();
   }
+  FailureOr<LinalgOp> poolingNhwcSum = tryApply<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>>(
+      target);
+  if (succeeded(poolingNhwcSum)) {
+    results.push_back(*poolingNhwcSum);
+    return DiagnosedSilenceableFailure::success();
+  }
+  FailureOr<LinalgOp> poolingNchwSum = tryApply<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>>(
+      target);
+  if (succeeded(poolingNchwSum)) {
+    results.push_back(*poolingNchwSum);
+    return DiagnosedSilenceableFailure::success();
+  }
+  FailureOr<LinalgOp> poolingNhwcMax = tryApply<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>>(
+      target);
+  if (succeeded(poolingNhwcMax)) {
+    results.push_back(*poolingNhwcMax);
+    return DiagnosedSilenceableFailure::success();
+  }
+  FailureOr<LinalgOp> poolingNhwcMaxUnsigned =
+      tryApply<DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
+                                                     PoolingNwcMaxUnsignedOp>>(
+          target);
+  if (succeeded(poolingNhwcMaxUnsigned)) {
+    results.push_back(*poolingNhwcMaxUnsigned);
+    return DiagnosedSilenceableFailure::success();
+  }
+  FailureOr<LinalgOp> poolingNhwcMin = tryApply<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>>(
+      target);
+  if (succeeded(poolingNhwcMin)) {
+    results.push_back(*poolingNhwcMin);
+    return DiagnosedSilenceableFailure::success();
+  }
+  FailureOr<LinalgOp> poolingNhwcMinUnsigned =
+      tryApply<DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
+                                                     PoolingNwcMinUnsignedOp>>(
+          target);
+  if (succeeded(poolingNhwcMinUnsigned)) {
+    results.push_back(*poolingNhwcMinUnsigned);
+    return DiagnosedSilenceableFailure::success();
+  }
+  FailureOr<LinalgOp> poolingNchwMax = tryApply<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
+      target);
+  if (succeeded(poolingNchwMax)) {
+    results.push_back(*poolingNchwMax);
+    return DiagnosedSilenceableFailure::success();
+  }
   results.assign(1, nullptr);
   return emitDefaultSilenceableFailure(target);
 }
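The seven patterns tried above all rank-reduce a 2-D pooling op whose height window is statically one into the corresponding new 1-D op. As an illustration only (the shapes and SSA names below are invented for this sketch; they are not taken from the patch or its tests), a linalg.pooling_nhwc_sum with a 1x2 window decomposes roughly as:

    // Before: 2-D sum pooling with a size-one H window.
    %res = linalg.pooling_nhwc_sum
             {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
             ins(%input, %filter : tensor<1x1x16x3xf32>, tensor<1x2xf32>)
             outs(%init : tensor<1x1x15x3xf32>) -> tensor<1x1x15x3xf32>

    // After: the unit H dimension is sliced away, the 1-D op does the work,
    // and the result is re-inserted.
    %in1d  = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 16, 3] [1, 1, 1, 1]
               : tensor<1x1x16x3xf32> to tensor<1x16x3xf32>
    %k1d   = tensor.extract_slice %filter[0, 0] [1, 2] [1, 1]
               : tensor<1x2xf32> to tensor<2xf32>
    %out1d = tensor.extract_slice %init[0, 0, 0, 0] [1, 1, 15, 3] [1, 1, 1, 1]
               : tensor<1x1x15x3xf32> to tensor<1x15x3xf32>
    %pool  = linalg.pooling_nwc_sum
               {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
               ins(%in1d, %k1d : tensor<1x16x3xf32>, tensor<2xf32>)
               outs(%out1d : tensor<1x15x3xf32>) -> tensor<1x15x3xf32>
    %res   = tensor.insert_slice %pool into %init[0, 0, 0, 0] [1, 1, 15, 3] [1, 1, 1, 1]
               : tensor<1x15x3xf32> into tensor<1x1x15x3xf32>

This structure (three rank-reducing tensor.extract_slice ops, the 1-D pooling, one tensor.insert_slice) is what the transform-op-decompose.mlir tests below check for.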
Index: mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -628,8 +628,50 @@
         ohIndex = 2;
         owIndex = 3;
       })
+      .Case([&](linalg::PoolingNhwcSumOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::PoolingNchwSumOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 2;
+        owIndex = 3;
+      })
+      .Case([&](linalg::PoolingNhwcMaxOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::PoolingNhwcMinOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::PoolingNchwMaxOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 2;
+        owIndex = 3;
+      })
       .Default([&](Operation *op) {
-        llvm_unreachable("unexpected conv2d operation.");
+        llvm_unreachable("unexpected conv2d/pool2d operation.");
       });
 
   // Only handle the case where at least one of the window dimensions is
   // equal to 1.
@@ -688,6 +730,20 @@
                                                               Conv1DNwcWcfOp>;
 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
                                                               Conv1DNcwFcwOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
+                                                              PoolingNwcSumOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
+                                                              PoolingNcwSumOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
+                                                              PoolingNwcMaxOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<
+    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
+                                                              PoolingNwcMinOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<
+    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
+                                                              PoolingNcwMaxOp>;
 
 FailureOr<DepthwiseConv1DNwcWcOp>
 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
@@ -765,4 +821,15 @@
                                                      Conv1DNcwFcwOp>,
                DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
                                                   benefit);
+  patterns.add<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
+                                            PoolingNwcMaxUnsignedOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
+                                            PoolingNwcMinUnsignedOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
+      patterns.getContext(), benefit);
 }
Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -385,8 +385,10 @@
           [&](auto op) { return CombiningKind::ADD; })
       .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
       .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
+      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
       .Case<arith::MaxFOp>([&](auto op) { return CombiningKind::MAXF; })
       .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
+      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinFOp>([&](auto op) { return CombiningKind::MINF; })
       .Case<arith::MulIOp, arith::MulFOp>(
           [&](auto op) { return CombiningKind::MUL; })
@@ -1838,41 +1840,75 @@
     resShapedType = resShaped.getType().dyn_cast<ShapedType>();
     if (!lhsShapedType || !rhsShapedType || !resShapedType)
       return;
-    if (lhsShapedType.getRank() != 3 ||
-        (rhsShapedType.getRank() != 2 && rhsShapedType.getRank() != 3) ||
-        resShapedType.getRank() != 3)
+    if (lhsShapedType.getRank() != 3)
       return;
 
-    // Check for reduction `add` preceded by `mul`.
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
     if (!reduceOp)
       return;
-    llvm::Optional<vector::CombiningKind> maybeKind;
-    maybeKind = getCombinerOpKind(reduceOp);
-    if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
-      return;
-    // Check for single `mul` predecessor. The `mul` operands must be block
-    // arguments or extension of block arguments.
-    Operation *mulOp = nullptr;
+    poolRedOp = reduceOp->getName().getIdentifier();
+
+    // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
+    // + yield) and rhs is not used) then it is the body of a pooling
+    // If conv, check for single `mul` predecessor. The `mul` operands must be
+    // block arguments or extension of block arguments.
+    // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
+    // must be block arguments or extension of block arguments.
+    Operation *feedOp = nullptr;
+    bool isOnlyContinue = true;
     for (Value operand : reduceOp->getOperands()) {
-      if (operand.isa<BlockArgument>())
+      if (operand.isa<BlockArgument>()) {
         continue;
-      if (mulOp)
-        return;
-      mulOp = operand.getDefiningOp();
-      if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
+      }
+      isOnlyContinue = false;
+      if (feedOp) {
         return;
-    }
-    if (!mulOp)
-      return;
-    for (Value operand : mulOp->getOperands()) {
-      if (Operation *def = operand.getDefiningOp()) {
-        if (!isa<CastOpInterface>(def))
+      }
+      feedOp = operand.getDefiningOp();
+      if (feedOp) {
+        if (isa<arith::MulIOp, arith::MulFOp>(feedOp)) {
+          for (Value localOperand : feedOp->getOperands()) {
+            if (Operation *def = localOperand.getDefiningOp()) {
+              if (!(localOperand.isa<BlockArgument>() ||
+                    (isa<CastOpInterface>(def) &&
+                     def->getOperand(0).isa<BlockArgument>()))) {
+                return;
+              }
+            }
+          }
+        } else if (isa<CastOpInterface>(feedOp) &&
+                   feedOp->getOperand(0).isa<BlockArgument>()) {
+          isPool = true;
+          poolExtOp = feedOp->getName().getIdentifier();
+          isPoolExt = true;
+        } else {
           return;
-        operand = def->getOperand(0);
+        }
       }
-      if (!operand.isa<BlockArgument>())
+    }
+    if (isOnlyContinue) {
+      isPool = true;
+    }
+    llvm::Optional<vector::CombiningKind> maybeKind =
+        getCombinerOpKind(reduceOp);
+    if (!(maybeKind &&
+          (*maybeKind == vector::CombiningKind::ADD ||
+           (isPool && (*maybeKind == vector::CombiningKind::MAXF ||
+                       *maybeKind == vector::CombiningKind::MAXSI ||
+                       *maybeKind == vector::CombiningKind::MAXUI ||
+                       *maybeKind == vector::CombiningKind::MINF ||
+                       *maybeKind == vector::CombiningKind::MINSI ||
+                       *maybeKind == vector::CombiningKind::MINUI))))) {
+      return;
+    }
+    if (isPool) {
+      if (!(rhsShapedType.getRank() == 1 && resShapedType.getRank() == 3))
        return;
+    } else {
+      if (!((rhsShapedType.getRank() == 2 || rhsShapedType.getRank() == 3) &&
+            resShapedType.getRank() == 3)) {
+        return;
+      }
     }
     // The op is now known to be valid.
     valid = true;
@@ -1888,40 +1924,77 @@
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
-    if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv");
+    if (!valid) {
+      if (isPool)
+        return rewriter.notifyMatchFailure(op, "unvectorizable 1-D pool");
+      else
+        return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv");
+    }
 
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
     switch (conv1DOpOrder) {
     case Conv1DOpOrder::Nwc:
-      // kernel{kw, c, f}
-      bindShapeDims(rhsShapedType, kwSize, cSize, fSize);
-      // out{n, w, f}
-      bindShapeDims(resShapedType, nSize, wSize);
-      lhsShape = {nSize,
-                  // iw = ow * sw + kw * dw - 1
-                  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
-                  // Perform the proper inclusive -> exclusive -> inclusive.
-                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
-                      1,
-                  cSize};
-      rhsShape = {kwSize, cSize, fSize};
-      resShape = {nSize, wSize, fSize};
+      if (isPool) {
+        // kernel{kw}
+        bindShapeDims(rhsShapedType, kwSize);
+        // out{n, w, c}
+        bindShapeDims(resShapedType, nSize, wSize, cSize);
+        lhsShape = {nSize,
+                    // iw = ow * sw + kw * dw - 1
+                    // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+                    // Perform the proper inclusive -> exclusive -> inclusive.
+                    ((wSize - 1) * strideW + 1) +
+                        ((kwSize - 1) * dilationW + 1) - 1,
+                    cSize};
+        rhsShape = {kwSize};
+        fSize = cSize;
+        resShape = {nSize, wSize, cSize};
+      } else {
+        // kernel{kw, c, f}
+        bindShapeDims(rhsShapedType, kwSize, cSize, fSize);
+        // out{n, w, f}
+        bindShapeDims(resShapedType, nSize, wSize);
+        lhsShape = {nSize,
+                    // iw = ow * sw + kw * dw - 1
+                    // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+                    // Perform the proper inclusive -> exclusive -> inclusive.
+                    ((wSize - 1) * strideW + 1) +
+                        ((kwSize - 1) * dilationW + 1) - 1,
+                    cSize};
+        rhsShape = {kwSize, cSize, fSize};
+        resShape = {nSize, wSize, fSize};
+      }
       break;
     case Conv1DOpOrder::Ncw:
-      // kernel{f, c, kw}
-      bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
-      // out{n, f, w}
-      bindShapeDims(resShapedType, nSize, fSize, wSize);
-      lhsShape = {nSize, cSize,
-                  // iw = ow * sw + kw * dw - 1
-                  // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
-                  // Perform the proper inclusive -> exclusive -> inclusive.
-                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
-                      1};
-      rhsShape = {fSize, cSize, kwSize};
-      resShape = {nSize, fSize, wSize};
+      if (isPool) {
+        // kernel{kw}
+        bindShapeDims(rhsShapedType, kwSize);
+        // out{n, c, w}
+        bindShapeDims(resShapedType, nSize, cSize, wSize);
+        lhsShape = {nSize, cSize,
+                    // iw = ow * sw + kw * dw - 1
+                    // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+                    // Perform the proper inclusive -> exclusive -> inclusive.
+                    ((wSize - 1) * strideW + 1) +
+                        ((kwSize - 1) * dilationW + 1) - 1};
+        rhsShape = {kwSize};
+        fSize = cSize;
+        resShape = {nSize, cSize, wSize};
+      } else {
+        // kernel{f, c, kw}
+        bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
+        // out{n, f, w}
+        bindShapeDims(resShapedType, nSize, fSize, wSize);
+        lhsShape = {nSize, cSize,
+                    // iw = ow * sw + kw * dw - 1
+                    // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+                    // Perform the proper inclusive -> exclusive -> inclusive.
+                    ((wSize - 1) * strideW + 1) +
+                        ((kwSize - 1) * dilationW + 1) - 1};
+        rhsShape = {fSize, cSize, kwSize};
+        resShape = {nSize, fSize, wSize};
+      }
       break;
     }
 
@@ -1944,8 +2017,11 @@
     Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
-    Value rhs = rewriter.create<vector::TransferReadOp>(
-        loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
+    Value rhs = nullptr;
+    if (!isPool) {
+      rhs = rewriter.create<vector::TransferReadOp>(
+          loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
+    }
     // Read res slice of size {n, w, f} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
@@ -1964,7 +2040,10 @@
       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
       // fcw -> wcf
       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
-      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
+
+      // Do not do for pooling
+      if (!isPool)
+        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
       // nfw -> nwf
       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
       res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
@@ -1989,8 +2068,10 @@
     }
     // Extract rhs slice of size {c, f} @ [kw].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
-      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+      if (!isPool) {
+        rhsVals.push_back(rewriter.create<vector::ExtractOp>(
+            loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+      }
     }
     // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
@@ -2005,11 +2086,18 @@
       return kw * (wSize / wSizeStep) + w;
     };
 
-    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
+    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
+    // Perform simple arith operation for pooling
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = conv1dSliceAsContraction(
-            rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
+        if (isPool) {
+          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
+                                   resVals[w]);
+        } else {
+          resVals[w] = conv1dSliceAsContraction(rewriter, loc,
+                                                lhsVals[linearIndex(kw, w)],
+                                                rhsVals[kw], resVals[w]);
+        }
       }
     }
 
@@ -2060,6 +2148,17 @@
         /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
   }
 
+  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
+  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
+                    Value res) {
+    if (isPoolExt) {
+      lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
+    }
+    return rewriter
+        .create(loc, poolRedOp, ArrayRef<Value>{lhs, res}, res.getType())
+        ->getResult(0);
+  }
+
   /// Generate a vector implementation for:
   /// ```
   /// Op def: ( n, w, c, kw)
@@ -2236,6 +2335,7 @@
                 /*rhsIndex*/ {kw, c, f},
                 /*resIndex*/ {n, w, f}}))
       return conv(Conv1DOpOrder::Nwc);
+
     return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
   }
 
@@ -2256,6 +2356,41 @@
     return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
   }
 
+  /// Entry point that transposes into the common form:
+  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
+  FailureOr<Operation *> generateNwcPooling() {
+    AffineExpr n, w, c, kw;
+    bindDims(ctx, n, w, c, kw);
+    if (!iters({Par(), Par(), Par(), Red()}))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to match pooling 3-par 1-red");
+
+    // No transposition needed.
+    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
+                /*rhsIndex*/ {kw},
+                /*resIndex*/ {n, w, c}}))
+      return conv(Conv1DOpOrder::Nwc);
+
+    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
+  }
+
+  /// Entry point that transposes into the common form:
+  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
+  FailureOr<Operation *> generateNcwPooling() {
+    AffineExpr n, w, c, kw;
+    bindDims(ctx, n, c, w, kw);
+    if (!iters({Par(), Par(), Par(), Red()}))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to match pooling 3-par 1-red");
+
+    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
+                /*rhsIndex*/ {kw},
+                /*resIndex*/ {n, c, w}}))
+      return conv(Conv1DOpOrder::Ncw);
+
+    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
+  }
+
   /// Entry point that transposes into the common form:
   /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
   FailureOr<Operation *> generateDilatedConv() {
@@ -2276,6 +2411,10 @@
 private:
   bool valid = false;
+  bool isPool = false;
+  StringAttr poolRedOp;
+  StringAttr poolExtOp;
+  bool isPoolExt = false;
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
@@ -2299,6 +2438,12 @@
   if (succeeded(res))
     return res;
   res = e.generateNcwConv();
+  if (succeeded(res))
+    return res;
+  res = e.generateNwcPooling();
+  if (succeeded(res))
+    return res;
+  res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
   return e.generateDilatedConv();
Index: mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
===================================================================
--- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -694,7 +694,6 @@
       D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed(
           U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
 
-
 @linalg_structured_op
 def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                                  S.OW * S.SW + S.KW * S.DW, S.C),
@@ -838,6 +837,146 @@
       D.c] = ReduceFn.min_unsigned[D.kh, D.kw](TypeFn.cast_unsigned(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                D.c]))
 
+
+@linalg_structured_op
+def pooling_nwc_sum(I=TensorDef(T1, S.N,
+                                S.OW * S.SW + S.KW * S.DW, S.C),
+                    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+                    O=TensorDef(U, S.N, S.OW, S.C, output=True),
+                    strides=IndexAttrDef(S.SW, default=[1]),
+                    dilations=IndexAttrDef(S.DW, default=[1])):
+  """Performs sum pooling.
+
+  Layout:
+    * Input: NWC.
+    * Kernel: W.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  implements(ConvolutionOpInterface)
+  domain(D.n, D.ow, D.c, D.kw)
+  O[D.n, D.ow, D.c] += TypeFn.cast_signed(
+      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+
+
+@linalg_structured_op
+def pooling_ncw_sum(I=TensorDef(T1, S.N, S.C,
+                                S.OW * S.SW + S.KW * S.DW),
+                    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+                    O=TensorDef(U, S.N, S.C, S.OW, output=True),
+                    strides=IndexAttrDef(S.SW, default=[1]),
+                    dilations=IndexAttrDef(S.DW, default=[1])):
+  """Performs sum pooling.
+
+  Layout:
+    * Input: NCW.
+    * Kernel: W.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.ow, D.kw) + O[D.n, D.c, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) + + +@linalg_structured_op +def pooling_nwc_max(I=TensorDef(T1, S.N, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) + + +@linalg_structured_op +def pooling_nwc_max_unsigned(I=TensorDef(T1, S.N, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, + S.KW, + index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, + D.c] = ReduceFn.max_unsigned[[D.kw]](TypeFn.cast_unsigned( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) + + +@linalg_structured_op +def pooling_ncw_max(I=TensorDef(T1, S.N, S.C, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.C, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.ow, D.kw) + O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed( + U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW,])) + + +@linalg_structured_op +def pooling_nwc_min(I=TensorDef(T1, S.N, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]](TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) + + +@linalg_structured_op +def pooling_nwc_min_unsigned(I=TensorDef(T1, S.N, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, + S.KW, + index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, + D.c] = ReduceFn.min_unsigned[[D.kw]](TypeFn.cast_unsigned( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) + + @linalg_structured_op def pooling_ndhwc_sum(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, Index: mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir =================================================================== --- mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -131,6 +131,20 @@ // ----- +func.func @generalize_pooling_nwc_max_f32(%input : tensor<1x16x1xf32>, %shape: tensor<2xf32>, %output: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> { + %0 = linalg.pooling_nwc_max {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>} + ins(%input, %shape : tensor<1x16x1xf32>, tensor<2xf32>) outs(%output : tensor<1x4x1xf32>) -> tensor<1x4x1xf32> + return %0: tensor<1x4x1xf32> +} + +// CHECK-LABEL: @generalize_pooling_nwc_max_f32 +// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32) +// CHECK-NEXT: %[[MAX:.+]] = arith.maxf %[[OUT_ARG]], %[[IN_ARG]] : f32 +// CHECK-NEXT: linalg.yield %[[MAX]] : f32 +// CHECK-NEXT: -> tensor<1x4x1xf32> + +// ----- + func.func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> { %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>} ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> @@ -143,6 +157,18 @@ // ----- +func.func @generalize_pooling_nwc_max_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> { + %0 = linalg.pooling_nwc_max {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>} + ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32> + return %0: tensor<1x4x1xi32> +} + +// CHECK-LABEL: @generalize_pooling_nwc_max_i32 +// Verify signed integer maximum. +// CHECK: = arith.maxsi + +// ----- + func.func @generalize_pooling_nhwc_max_unsigned_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> { %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>} ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> @@ -155,6 +181,18 @@ // ----- +func.func @generalize_pooling_nwc_max_unsigned_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> { + %0 = linalg.pooling_nwc_max_unsigned {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>} + ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32> + return %0: tensor<1x4x1xi32> +} + +// CHECK-LABEL: @generalize_pooling_nwc_max_unsigned_i32 +// Verify unsigned integer minimum. 
+// CHECK: = arith.maxui + +// ----- + func.func @generalize_pooling_nhwc_min_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> { %0 = linalg.pooling_nhwc_min {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>} ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> @@ -169,6 +207,20 @@ // ----- +func.func @generalize_pooling_nwc_min_f32(%input : tensor<1x16x1xf32>, %shape: tensor<2xf32>, %output: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> { + %0 = linalg.pooling_nwc_min {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>} + ins(%input, %shape : tensor<1x16x1xf32>, tensor<2xf32>) outs(%output : tensor<1x4x1xf32>) -> tensor<1x4x1xf32> + return %0: tensor<1x4x1xf32> +} + +// CHECK-LABEL: @generalize_pooling_nwc_min_f32 +// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32) +// CHECK-NEXT: %[[MIN:.+]] = arith.minf %[[OUT_ARG]], %[[IN_ARG]] : f32 +// CHECK-NEXT: linalg.yield %[[MIN]] : f32 +// CHECK-NEXT: -> tensor<1x4x1xf32> + +// ----- + func.func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> { %0 = linalg.pooling_nhwc_min {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>} ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> @@ -181,6 +233,18 @@ // ----- +func.func @generalize_pooling_nwc_min_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> { + %0 = linalg.pooling_nwc_min {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>} + ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32> + return %0: tensor<1x4x1xi32> +} + +// CHECK-LABEL: @generalize_pooling_nwc_min_i32 +// Verify signed integer minimum. +// CHECK: = arith.minsi + +// ----- + func.func @generalize_pooling_nhwc_min_unsigned_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> { %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>} ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> @@ -193,6 +257,18 @@ // ----- +func.func @generalize_pooling_nwc_min_unsigned_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> { + %0 = linalg.pooling_nwc_min_unsigned {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>} + ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32> + return %0: tensor<1x4x1xi32> +} + +// CHECK-LABEL: @generalize_pooling_nwc_min_unsigned_i32 +// Verify unsigned integer minimum. 
+// CHECK: = arith.minui + +// ----- + func.func @generalize_pooling_nhwc_sum_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> { %0 = linalg.pooling_nhwc_sum {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>} ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> @@ -207,6 +283,20 @@ // ----- +func.func @generalize_pooling_nwc_sum_f32(%input : tensor<1x16x1xf32>, %shape: tensor<2xf32>, %output: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> { + %0 = linalg.pooling_nwc_sum {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>} + ins(%input, %shape : tensor<1x16x1xf32>, tensor<2xf32>) outs(%output : tensor<1x4x1xf32>) -> tensor<1x4x1xf32> + return %0: tensor<1x4x1xf32> +} + +// CHECK-LABEL: @generalize_pooling_nwc_sum_f32 +// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32) +// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[OUT_ARG]], %[[IN_ARG]] : f32 +// CHECK-NEXT: linalg.yield %[[ADD]] : f32 +// CHECK-NEXT: -> tensor<1x4x1xf32> + +// ----- + func.func @generalize_pooling_nhwc_sum_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> { %0 = linalg.pooling_nhwc_sum {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>} ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> @@ -221,6 +311,20 @@ // ----- +func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: tensor<2xi32>, %output: tensor<1x4x1xi32>) -> tensor<1x4x1xi32> { + %0 = linalg.pooling_nwc_sum {dilations = dense<[2]> : tensor<1xi64>, strides = dense<[4]> : tensor<1xi64>} + ins(%input, %shape : tensor<1x16x1xi32>, tensor<2xi32>) outs(%output : tensor<1x4x1xi32>) -> tensor<1x4x1xi32> + return %0: tensor<1x4x1xi32> +} + +// CHECK-LABEL: @generalize_pooling_nwc_sum_i32 +// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32) +// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[OUT_ARG]], %[[IN_ARG]] : i32 +// CHECK-NEXT: linalg.yield %[[ADD]] : i32 +// CHECK-NEXT: -> tensor<1x4x1xi32> + +// ----- + func.func @generalize_fill_0d(%value: f64, %O: tensor) -> tensor { %0 = linalg.fill ins(%value: f64) outs(%O : tensor) -> tensor return %0: tensor Index: mlir/test/Dialect/Linalg/named-ops.mlir =================================================================== --- mlir/test/Dialect/Linalg/named-ops.mlir +++ mlir/test/Dialect/Linalg/named-ops.mlir @@ -422,6 +422,25 @@ // ----- +// CHECK-LABEL: func @pooling_nwc_sum_tensor +// CHECK: %{{.+}} = linalg.pooling_nwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x1xf32>, tensor<3xf32>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x1xf32>) -> tensor<1x2x1xf32> +func.func @pooling_nwc_sum_tensor(%input: tensor<1x4x1xf32>) -> tensor<1x2x1xf32> { + %fake = tensor.empty() : tensor<3xf32> + %init = tensor.empty() : tensor<1x2x1xf32> + %cst = arith.constant 0.000000e+00 : f32 + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x1xf32>) -> tensor<1x2x1xf32> + %res = linalg.pooling_nwc_sum {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %fake: tensor<1x4x1xf32>, tensor<3xf32>) + outs(%fill: tensor<1x2x1xf32>) -> tensor<1x2x1xf32> + return %res : 
tensor<1x2x1xf32> +} + +// ----- + // CHECK-LABEL: func @pooling_nhwc_sum // CHECK: linalg.pooling_nhwc_sum // CHECK-SAME: dilations = dense<1> : tensor<2xi64> @@ -437,6 +456,21 @@ // ----- +// CHECK-LABEL: func @pooling_nwc_sum +// CHECK: linalg.pooling_nwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x1xf32>, memref<3xf32>) +// CHECK-SAME: outs(%{{.+}} : memref<1x2x1xf32>) +func.func @pooling_nwc_sum(%input: memref<1x4x1xf32>, %fake: memref<3xf32>, %output: memref<1x2x1xf32>) { + linalg.pooling_nwc_sum {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %fake: memref<1x4x1xf32>, memref<3xf32>) + outs(%output: memref<1x2x1xf32>) + return +} + +// ----- + // CHECK-LABEL: func @pooling_nchw_sum_tensor // CHECK: %{{.+}} = linalg.pooling_nchw_sum // CHECK-SAME: dilations = dense<1> : tensor<2xi64> @@ -456,6 +490,25 @@ // ----- +// CHECK-LABEL: func @pooling_ncw_sum_tensor +// CHECK: %{{.+}} = linalg.pooling_ncw_sum +// CHECK-SAME: dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x1x4xf32>, tensor<3xf32>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x1x2xf32>) -> tensor<1x1x2xf32> +func.func @pooling_ncw_sum_tensor(%input: tensor<1x1x4xf32>) -> tensor<1x1x2xf32> { + %fake = tensor.empty() : tensor<3xf32> + %init = tensor.empty() : tensor<1x1x2xf32> + %cst = arith.constant 0.000000e+00 : f32 + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x1x2xf32>) -> tensor<1x1x2xf32> + %res = linalg.pooling_ncw_sum {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %fake: tensor<1x1x4xf32>, tensor<3xf32>) + outs(%fill: tensor<1x1x2xf32>) -> tensor<1x1x2xf32> + return %res : tensor<1x1x2xf32> +} + +// ----- + // CHECK-LABEL: func @pooling_nchw_sum // CHECK: linalg.pooling_nchw_sum // CHECK-SAME: dilations = dense<1> : tensor<2xi64> @@ -471,6 +524,21 @@ // ----- +// CHECK-LABEL: func @pooling_ncw_sum +// CHECK: linalg.pooling_ncw_sum +// CHECK-SAME: dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x1x4xf32>, memref<3xf32>) +// CHECK-SAME: outs(%{{.+}} : memref<1x1x2xf32>) +func.func @pooling_ncw_sum(%input: memref<1x1x4xf32>, %fake: memref<3xf32>, %output: memref<1x1x2xf32>) { + linalg.pooling_ncw_sum {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %fake: memref<1x1x4xf32>, memref<3xf32>) + outs(%output: memref<1x1x2xf32>) + return +} + +// ----- + // CHECK-LABEL: func @pooling_nhwc_max_tensor // CHECK: %{{.+}} = linalg.pooling_nhwc_max // CHECK-SAME: dilations = dense<1> : tensor<2xi64> @@ -488,6 +556,24 @@ return %res : tensor<1x2x2x1xf32> } +// ----- +// CHECK-LABEL: func @pooling_nwc_max_tensor +// CHECK: %{{.+}} = linalg.pooling_nwc_max +// CHECK-SAME: dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x1xf32>, tensor<3xf32>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x1xf32>) -> tensor<1x2x1xf32> +func.func @pooling_nwc_max_tensor(%input: tensor<1x4x1xf32>) -> tensor<1x2x1xf32> { + %fake = tensor.empty() : tensor<3xf32> + %init = tensor.empty() : tensor<1x2x1xf32> + %cst = arith.constant 0.000000e+00 : f32 + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x1xf32>) -> tensor<1x2x1xf32> + %res = 
linalg.pooling_nwc_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %fake: tensor<1x4x1xf32>, tensor<3xf32>) + outs(%fill: tensor<1x2x1xf32>) -> tensor<1x2x1xf32> + return %res : tensor<1x2x1xf32> +} + // ----- // CHECK-LABEL: func @pooling_nchw_max_tensor // CHECK: %{{.+}} = linalg.pooling_nchw_max @@ -507,6 +593,25 @@ return %res : tensor<1x1x2x2xf32> } +// ----- +// CHECK-LABEL: func @pooling_ncw_max_tensor +// CHECK: %{{.+}} = linalg.pooling_ncw_max +// CHECK-SAME: dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x1x4xf32>, tensor<3xf32>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x1x2xf32>) -> tensor<1x1x2xf32> + +func.func @pooling_ncw_max_tensor(%input: tensor<1x1x4xf32>) -> tensor<1x1x2xf32> { + %fake = tensor.empty() : tensor<3xf32> + %init = tensor.empty() : tensor<1x1x2xf32> + %cst = arith.constant 0.000000e+00 : f32 + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x1x2xf32>) -> tensor<1x1x2xf32> + %res = linalg.pooling_ncw_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %fake: tensor<1x1x4xf32>, tensor<3xf32>) + outs(%fill: tensor<1x1x2xf32>) -> tensor<1x1x2xf32> + return %res : tensor<1x1x2xf32> +} + // ----- // CHECK-LABEL: func @pooling_nhwc_max @@ -524,6 +629,21 @@ // ----- +// CHECK-LABEL: func @pooling_nwc_max +// CHECK: linalg.pooling_nwc_max +// CHECK-SAME: dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x1xf32>, memref<3xf32>) +// CHECK-SAME: outs(%{{.+}} : memref<1x2x1xf32>) +func.func @pooling_nwc_max(%input: memref<1x4x1xf32>, %fake: memref<3xf32>, %output: memref<1x2x1xf32>) { + linalg.pooling_nwc_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %fake: memref<1x4x1xf32>, memref<3xf32>) + outs(%output: memref<1x2x1xf32>) + return +} + +// ----- + // CHECK-LABEL: func @pooling_nhwc_i8_max_tensor // CHECK: %{{.+}} = linalg.pooling_nhwc_max // CHECK-SAME: dilations = dense<1> : tensor<2xi64> @@ -543,6 +663,25 @@ // ----- +// CHECK-LABEL: func @pooling_nwc_i8_max_tensor +// CHECK: %{{.+}} = linalg.pooling_nwc_max +// CHECK-SAME: dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x1xi8>, tensor<3xi8>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x1xi8>) -> tensor<1x2x1xi8> +func.func @pooling_nwc_i8_max_tensor(%input: tensor<1x4x1xi8>) -> tensor<1x2x1xi8> { + %fake = tensor.empty() : tensor<3xi8> + %init = tensor.empty() : tensor<1x2x1xi8> + %cst = arith.constant 0 : i8 + %fill = linalg.fill ins(%cst : i8) outs(%init : tensor<1x2x1xi8>) -> tensor<1x2x1xi8> + %res = linalg.pooling_nwc_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %fake: tensor<1x4x1xi8>, tensor<3xi8>) + outs(%fill: tensor<1x2x1xi8>) -> tensor<1x2x1xi8> + return %res : tensor<1x2x1xi8> +} + +// ----- + // CHECK-LABEL: func @pooling_nhwc_i8_max // CHECK: linalg.pooling_nhwc_max // CHECK-SAME: dilations = dense<1> : tensor<2xi64> @@ -558,6 +697,21 @@ // ----- +// CHECK-LABEL: func @pooling_nwc_i8_max +// CHECK: linalg.pooling_nwc_max +// CHECK-SAME: dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x1xi8>, memref<3xi8>) +// CHECK-SAME: outs(%{{.+}} : memref<1x2x1xi8>) +func.func 
@pooling_nwc_i8_max(%input: memref<1x4x1xi8>, %fake: memref<3xi8>, %output: memref<1x2x1xi8>) { + linalg.pooling_nwc_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %fake: memref<1x4x1xi8>, memref<3xi8>) + outs(%output: memref<1x2x1xi8>) + return +} + +// ----- + // CHECK-LABEL: func @pooling_nhwc_i16_max_tensor // CHECK: %{{.+}} = linalg.pooling_nhwc_max // CHECK-SAME: dilations = dense<1> : tensor<2xi64> @@ -577,6 +731,25 @@ // ----- +// CHECK-LABEL: func @pooling_nwc_i16_max_tensor +// CHECK: %{{.+}} = linalg.pooling_nwc_max +// CHECK-SAME: dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x1xi16>, tensor<3xi16>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x1xi16>) -> tensor<1x2x1xi16> +func.func @pooling_nwc_i16_max_tensor(%input: tensor<1x4x1xi16>) -> tensor<1x2x1xi16> { + %fake = tensor.empty() : tensor<3xi16> + %init = tensor.empty() : tensor<1x2x1xi16> + %cst = arith.constant 0 : i16 + %fill = linalg.fill ins(%cst : i16) outs(%init : tensor<1x2x1xi16>) -> tensor<1x2x1xi16> + %res = linalg.pooling_nwc_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %fake: tensor<1x4x1xi16>, tensor<3xi16>) + outs(%fill: tensor<1x2x1xi16>) -> tensor<1x2x1xi16> + return %res : tensor<1x2x1xi16> +} + +// ----- + // CHECK-LABEL: func @pooling_nhwc_i16_max // CHECK: linalg.pooling_nhwc_max // CHECK-SAME: dilations = dense<1> : tensor<2xi64> @@ -592,6 +765,21 @@ // ----- +// CHECK-LABEL: func @pooling_nwc_i16_max +// CHECK: linalg.pooling_nwc_max +// CHECK-SAME: dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x1xi16>, memref<3xi16>) +// CHECK-SAME: outs(%{{.+}} : memref<1x2x1xi16>) +func.func @pooling_nwc_i16_max(%input: memref<1x4x1xi16>, %fake: memref<3xi16>, %output: memref<1x2x1xi16>) { + linalg.pooling_nwc_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %fake: memref<1x4x1xi16>, memref<3xi16>) + outs(%output: memref<1x2x1xi16>) + return +} + +// ----- + // CHECK-LABEL: func @pooling_nhwc_i32_max_tensor // CHECK: %{{.+}} = linalg.pooling_nhwc_max // CHECK-SAME: dilations = dense<1> : tensor<2xi64> @@ -611,6 +799,25 @@ // ----- +// CHECK-LABEL: func @pooling_nwc_i32_max_tensor +// CHECK: %{{.+}} = linalg.pooling_nwc_max +// CHECK-SAME: dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x1xi32>, tensor<3xi32>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x1xi32>) -> tensor<1x2x1xi32> +func.func @pooling_nwc_i32_max_tensor(%input: tensor<1x4x1xi32>) -> tensor<1x2x1xi32> { + %fake = tensor.empty() : tensor<3xi32> + %init = tensor.empty() : tensor<1x2x1xi32> + %cst = arith.constant 0 : i32 + %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x1xi32>) -> tensor<1x2x1xi32> + %res = linalg.pooling_nwc_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %fake: tensor<1x4x1xi32>, tensor<3xi32>) + outs(%fill: tensor<1x2x1xi32>) -> tensor<1x2x1xi32> + return %res : tensor<1x2x1xi32> +} + +// ----- + // CHECK-LABEL: func @pooling_nhwc_i32_max // CHECK: linalg.pooling_nhwc_max // CHECK-SAME: dilations = dense<1> : tensor<2xi64> @@ -624,6 +831,21 @@ return } +// ----- + +// CHECK-LABEL: func @pooling_nwc_i32_max +// CHECK: linalg.pooling_nwc_max +// CHECK-SAME: dilations = dense<1> : 
tensor<1xi64>
+// CHECK-SAME: strides = dense<1> : tensor<1xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x1xi32>, memref<3xi32>)
+// CHECK-SAME: outs(%{{.+}} : memref<1x2x1xi32>)
+func.func @pooling_nwc_i32_max(%input: memref<1x4x1xi32>, %fake: memref<3xi32>, %output: memref<1x2x1xi32>) {
+  linalg.pooling_nwc_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %fake: memref<1x4x1xi32>, memref<3xi32>)
+    outs(%output: memref<1x2x1xi32>)
+  return
+}
+
 // -----
 
@@ -646,6 +868,25 @@
 
 // -----
 
+// CHECK-LABEL: func @pooling_nwc_min_tensor
+// CHECK: %{{.+}} = linalg.pooling_nwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<1xi64>
+// CHECK-SAME: strides = dense<1> : tensor<1xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x1xf32>, tensor<3xf32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
+func.func @pooling_nwc_min_tensor(%input: tensor<1x4x1xf32>) -> tensor<1x2x1xf32> {
+  %fake = tensor.empty() : tensor<3xf32>
+  %init = tensor.empty() : tensor<1x2x1xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
+  %res = linalg.pooling_nwc_min {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %fake: tensor<1x4x1xf32>, tensor<3xf32>)
+    outs(%fill: tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
+  return %res : tensor<1x2x1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @pooling_nhwc_min
 // CHECK: linalg.pooling_nhwc_min
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
@@ -661,6 +902,21 @@
 
 // -----
 
+// CHECK-LABEL: func @pooling_nwc_min
+// CHECK: linalg.pooling_nwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<1xi64>
+// CHECK-SAME: strides = dense<1> : tensor<1xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x1xf32>, memref<3xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<1x2x1xf32>)
+func.func @pooling_nwc_min(%input: memref<1x4x1xf32>, %fake: memref<3xf32>, %output: memref<1x2x1xf32>) {
+  linalg.pooling_nwc_min {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %fake: memref<1x4x1xf32>, memref<3xf32>)
+    outs(%output: memref<1x2x1xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @pooling_ndhwc_sum_tensor
 // CHECK: %{{.+}} = linalg.pooling_ndhwc_sum
 // CHECK-SAME: dilations = dense<1> : tensor<3xi64>
Index: mlir/test/Dialect/Linalg/transform-op-decompose.mlir
===================================================================
--- mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -56,6 +56,132 @@
   return %0: tensor<1x1x56x96xf32>
 }
 
+// CHECK-LABEL: @pooling_nhwc_sum
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_sum(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_sum
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nchw_sum
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-LABEL: @pooling_nchw_sum
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_sum
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x?x1x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_max
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_max_unsigned
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                         strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_min
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
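Every function in this file is rewritten by the transform script at the end of the test file; the hunk below only carries its opening lines as context. For orientation, a minimal script of that shape would look as follows, assuming the rewrite is driven by transform.structured.decompose (a sketch, not part of the patch):

transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
  // Match every LinalgOp in the payload, then decompose each match.
  %0 = transform.structured.match interface{LinalgOp} in %arg1
  %1 = transform.structured.decompose %0
}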
+// CHECK-LABEL: @pooling_nhwc_min_unsigned
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                         strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nchw_max
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_max
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x?x1x?xf32>
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match interface{LinalgOp} in %arg1
Index: mlir/test/Dialect/Linalg/vectorize-convolution.mlir
===================================================================
--- mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -571,3 +571,282 @@
 // CHECK: %[[CONT:.*]] = vector.contract
 // {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
 // CHECK: vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+
+
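The pooling vectorization tests that follow all share one recipe: the whole input and output are read into vectors, and the kernel window (and, for non-unit strides, the output width) is fully unrolled. For each (kw, ow) pair, the input slice starting at w = ow * stride + kw * dilation is extracted, combined with the matching output slice (arith.addf for sum pooling, arith.maxf for max pooling), and inserted back, followed by a single transfer_write. In the first test below (KW = 1, OW = 2, dilation = 1, stride = 3) that formula gives input offsets 0 * 3 + 0 = 0 and 1 * 3 + 0 = 3, which is exactly what the two extract_strided_slice checks on the input anchor on.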
+func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
+  linalg.pooling_nwc_sum
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x4x3xf32>, memref<1xf32>)
+    outs(%output : memref<4x2x3xf32>)
+  return
+}
+
+// CHECK: func.func @pooling_nwc_sum_memref_1_2_1_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xf32>, %[[Varg1:.+]]: memref<1xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x4x3xf32>, vector<4x4x3xf32>
+// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
+// CHECK-DAG: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V6:.+]] = arith.addf %[[V2]], %[[V4]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V7:.+]] = arith.addf %[[V3]], %[[V5]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V8:.+]] = vector.insert_strided_slice %[[V6]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK-DAG: %[[V9:.+]] = vector.insert_strided_slice %[[V7]], %[[V8]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK-DAG: vector.transfer_write %[[V9]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
+
+// -----
+
+func.func @pooling_nwc_max_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
+  linalg.pooling_nwc_max
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x4x3xf32>, memref<1xf32>)
+    outs(%output : memref<4x2x3xf32>)
+  return
+}
+
+// CHECK: func.func @pooling_nwc_max_memref_1_2_1_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xf32>, %[[Varg1:.+]]: memref<1xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x4x3xf32>, vector<4x4x3xf32>
+// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
+// CHECK-DAG: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V6:.+]] = arith.maxf %[[V2]], %[[V4]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V7:.+]] = arith.maxf %[[V3]], %[[V5]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V8:.+]] = vector.insert_strided_slice %[[V6]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK-DAG: %[[V9:.+]] = vector.insert_strided_slice %[[V7]], %[[V8]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK-DAG: vector.transfer_write %[[V9]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
+
+// -----
+
+// The i8i8i32 case is similar to the f32 case, so checking one case is enough for
+// test coverage.
+func.func @pooling_nwc_sum_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %filter: memref<1xi8>, %output: memref<4x2x3xi32>) {
+  linalg.pooling_nwc_sum
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x4x3xi8>, memref<1xi8>)
+    outs(%output : memref<4x2x3xi32>)
+  return
+}
+
+// CHECK: func.func @pooling_nwc_sum_i8i8i32_memref_1_2_1_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xi8>, %[[Varg1:.+]]: memref<1xi8>, %[[Varg2:.+]]: memref<4x2x3xi32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vc0_i8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[Vc0_i32:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i8]] {in_bounds = [true, true, true]} : memref<4x4x3xi8>, vector<4x4x3xi8>
+// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i32]] {in_bounds = [true, true, true]} : memref<4x2x3xi32>, vector<4x2x3xi32>
+// CHECK-DAG: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
+// CHECK-DAG: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
+// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32>
+// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32>
+// CHECK-DAG: %[[V6:.+]] = arith.extsi %[[V2]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+// CHECK-DAG: %[[V7:.+]] = arith.addi %[[V6]], %[[V4]] : vector<4x1x3xi32>
+// CHECK-DAG: %[[V8:.+]] = arith.extsi %[[V3]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+// CHECK-DAG: %[[V9:.+]] = arith.addi %[[V8]], %[[V5]] : vector<4x1x3xi32>
+// CHECK-DAG: %[[V10:.+]] = vector.insert_strided_slice %[[V7]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32>
+// CHECK-DAG: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32>
+// CHECK-DAG: vector.transfer_write %[[V11]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xi32>, memref<4x2x3xi32>
+// CHECK-DAG: return
+
+// -----
+
+// The i8i8i32 case is similar to the f32 case, so checking one case is enough for
+// test coverage.
+func.func @pooling_nwc_max_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %filter: memref<1xi8>, %output: memref<4x2x3xi32>) {
+  linalg.pooling_nwc_max
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x4x3xi8>, memref<1xi8>)
+    outs(%output : memref<4x2x3xi32>)
+  return
+}
+
+// CHECK: func.func @pooling_nwc_max_i8i8i32_memref_1_2_1_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xi8>, %[[Varg1:.+]]: memref<1xi8>, %[[Varg2:.+]]: memref<4x2x3xi32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vc0_i8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[Vc0_i32:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i8]] {in_bounds = [true, true, true]} : memref<4x4x3xi8>, vector<4x4x3xi8>
+// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i32]] {in_bounds = [true, true, true]} : memref<4x2x3xi32>, vector<4x2x3xi32>
+// CHECK-DAG: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
+// CHECK-DAG: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
+// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32>
+// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32>
+// CHECK-DAG: %[[V6:.+]] = arith.extsi %[[V2]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+// CHECK-DAG: %[[V7:.+]] = arith.maxsi %[[V6]], %[[V4]] : vector<4x1x3xi32>
+// CHECK-DAG: %[[V8:.+]] = arith.extsi %[[V3]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+// CHECK-DAG: %[[V9:.+]] = arith.maxsi %[[V8]], %[[V5]] : vector<4x1x3xi32>
+// CHECK-DAG: %[[V10:.+]] = vector.insert_strided_slice %[[V7]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32>
+// CHECK-DAG: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32>
+// CHECK-DAG: vector.transfer_write %[[V11]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xi32>, memref<4x2x3xi32>
+// CHECK-DAG: return
+
+// -----
+
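The two integer tests above make the promotion rule visible: per the ops' cast_signed semantics, each i8 input slice is sign-extended to the i32 accumulator type before the combining op, so sum pooling lowers to extsi + addi and max pooling to extsi + maxsi. The scalar equivalent, as a sketch (names invented; an unsigned pooling variant would use arith.extui and arith.maxui instead):

%wide = arith.extsi %elem : i8 to i32  // promote the input element
%sum = arith.addi %wide, %acc : i32    // combiner for pooling_nwc_sum
%max = arith.maxsi %wide, %acc : i32   // combiner for pooling_nwc_max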
+func.func @pooling_nwc_sum_memref_2_2_2_3(%input: memref<4x6x3xf32>, %filter: memref<2xf32>, %output: memref<4x2x3xf32>) {
+  linalg.pooling_nwc_sum
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xf32>, memref<2xf32>)
+    outs(%output : memref<4x2x3xf32>)
+  return
+}
+
+// CHECK: func.func @pooling_nwc_sum_memref_2_2_2_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x6x3xf32>, %[[Varg1:.+]]: memref<2xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x6x3xf32>, vector<4x6x3xf32>
+// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
+// CHECK-DAG: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V6:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V7:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V8:.+]] = arith.addf %[[V2]], %[[V6]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V9:.+]] = arith.addf %[[V3]], %[[V7]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V10:.+]] = arith.addf %[[V4]], %[[V8]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V11:.+]] = arith.addf %[[V5]], %[[V9]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V12:.+]] = vector.insert_strided_slice %[[V10]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK-DAG: %[[V13:.+]] = vector.insert_strided_slice %[[V11]], %[[V12]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK-DAG: vector.transfer_write %[[V13]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
+
+
+// -----
+
+func.func @pooling_ncw_sum_memref_1_2_1_3(%input: memref<4x3x4xf32>, %filter: memref<1xf32>, %output: memref<4x3x2xf32>) {
+  linalg.pooling_ncw_sum
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x3x4xf32>, memref<1xf32>)
+    outs(%output : memref<4x3x2xf32>)
+  return
+}
+
+// CHECK: func.func @pooling_ncw_sum_memref_1_2_1_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x3x4xf32>, %[[Varg1:.+]]: memref<1xf32>, %[[Varg2:.+]]: memref<4x3x2xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x4xf32>, vector<4x3x4xf32>
+// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x2xf32>, vector<4x3x2xf32>
+// CHECK-DAG: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x3x4xf32> to vector<4x4x3xf32>
+// CHECK-DAG: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
+// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V6:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V7:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V8:.+]] = arith.addf %[[V4]], %[[V6]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V9:.+]] = arith.addf %[[V5]], %[[V7]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V10:.+]] = vector.insert_strided_slice %[[V8]], %[[V3]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK-DAG: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK-DAG: %[[V12:.+]] = vector.transpose %[[V11]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
+// CHECK-DAG: vector.transfer_write %[[V12]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x3x2xf32>, memref<4x3x2xf32>
+
+
+// -----
+
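Channel-first layouts get no lowering of their own: as the checks above show, the NCW operands are transposed to NWC, the same extract/add/insert sequence as in the NWC tests runs on the transposed vectors, and the result is transposed back before the write. Schematically (a sketch with invented names, shapes taken from the test above):

%in_nwc = vector.transpose %in_ncw, [0, 2, 1] : vector<4x3x4xf32> to vector<4x4x3xf32>
%out_nwc = vector.transpose %out_ncw, [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
// ... the usual NWC pooling arithmetic on the transposed vectors ...
%res_ncw = vector.transpose %res_nwc, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>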
+func.func @pooling_nwc_sum_mixed_type_memref_1_2_1_1(%input: memref<1x2x3xf16>, %filter: memref<1xf16>, %output: memref<1x2x3xf32>) {
+  linalg.pooling_nwc_sum
+    {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : memref<1x2x3xf16>, memref<1xf16>)
+    outs(%output : memref<1x2x3xf32>)
+  return
+}
+
+// CHECK: func.func @pooling_nwc_sum_mixed_type_memref_1_2_1_1
+// CHECK-SAME: (%[[Varg0:.+]]: memref<1x2x3xf16>, %[[Varg1:.+]]: memref<1xf16>, %[[Varg2:.+]]: memref<1x2x3xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[Vcst_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<1x2x3xf16>, vector<1x2x3xf16>
+// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst_0]] {in_bounds = [true, true, true]} : memref<1x2x3xf32>, vector<1x2x3xf32>
+// CHECK-DAG: %[[V2:.+]] = arith.extf %[[V0]] : vector<1x2x3xf16> to vector<1x2x3xf32>
+// CHECK-DAG: %[[V3:.+]] = arith.addf %[[V2]], %[[V1]] : vector<1x2x3xf32>
+// CHECK-DAG: vector.transfer_write %[[V3]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<1x2x3xf32>, memref<1x2x3xf32>
+
+// -----
+
+func.func @pooling_nwc_sum_memref_2_2_2_1(%input: memref<4x4x3xf32>, %filter: memref<2xf32>, %output: memref<4x2x3xf32>) {
+  linalg.pooling_nwc_sum
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x4x3xf32>, memref<2xf32>)
+    outs(%output : memref<4x2x3xf32>)
+  return
+}
+
+// CHECK: func.func @pooling_nwc_sum_memref_2_2_2_1
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xf32>, %[[Varg1:.+]]: memref<2xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x4x3xf32>, vector<4x4x3xf32>
+// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
+// CHECK-DAG: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
+// CHECK-DAG: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
+// CHECK-DAG: %[[V4:.+]] = arith.addf %[[V2]], %[[V1]] : vector<4x2x3xf32>
+// CHECK-DAG: %[[V5:.+]] = arith.addf %[[V3]], %[[V4]] : vector<4x2x3xf32>
+// CHECK-DAG: vector.transfer_write %[[V5]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
+
+
+// -----
+
+func.func @pooling_ncw_sum_memref_2_2_2_3(%input: memref<4x3x6xf32>, %filter: memref<2xf32>, %output: memref<4x3x2xf32>) {
+  linalg.pooling_ncw_sum
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x3x6xf32>, memref<2xf32>)
+    outs(%output : memref<4x3x2xf32>)
+  return
+}
+
+// CHECK: func.func @pooling_ncw_sum_memref_2_2_2_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x3x6xf32>, %[[Varg1:.+]]: memref<2xf32>, %[[Varg2:.+]]: memref<4x3x2xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x6xf32>, vector<4x3x6xf32>
+// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x2xf32>, vector<4x3x2xf32>
+// CHECK-DAG: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x3x6xf32> to vector<4x6x3xf32>
+// CHECK-DAG: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
+// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V6:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V7:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V8:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V9:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK-DAG: %[[V10:.+]] = arith.addf %[[V4]], %[[V8]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V11:.+]] = arith.addf %[[V5]], %[[V9]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V12:.+]] = arith.addf %[[V6]], %[[V10]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V13:.+]] = arith.addf %[[V7]], %[[V11]] : vector<4x1x3xf32>
+// CHECK-DAG: %[[V14:.+]] = vector.insert_strided_slice %[[V12]], %[[V3]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK-DAG: %[[V15:.+]] = vector.insert_strided_slice %[[V13]], %[[V14]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK-DAG: %[[V16:.+]] = vector.transpose %[[V15]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
+// CHECK-DAG: vector.transfer_write %[[V16]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x3x2xf32>, memref<4x3x2xf32>
+
+// -----
+
+func.func @pooling_ncw_sum_memref_2_3_2_1(%input: memref<4x2x5xf32>, %filter: memref<2xf32>, %output: memref<4x2x3xf32>) {
+  linalg.pooling_ncw_sum
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x2x5xf32>, memref<2xf32>)
+    outs(%output : memref<4x2x3xf32>)
+  return
+}
+
+// CHECK: func.func @pooling_ncw_sum_memref_2_3_2_1
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x2x5xf32>, %[[Varg1:.+]]: memref<2xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x5xf32>, vector<4x2x5xf32>
+// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
+// CHECK-DAG: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x2x5xf32> to vector<4x5x2xf32>
+// CHECK-DAG: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
+// CHECK-DAG: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 3, 2], strides = [1, 1, 1]} : vector<4x5x2xf32> to vector<4x3x2xf32>
+// CHECK-DAG: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 2, 0], sizes = [4, 3, 2], strides = [1, 1, 1]} : vector<4x5x2xf32> to vector<4x3x2xf32>
+// CHECK-DAG: %[[V6:.+]] = arith.addf %[[V4]], %[[V3]] : vector<4x3x2xf32>
+// CHECK-DAG: %[[V7:.+]] = arith.addf %[[V5]], %[[V6]] : vector<4x3x2xf32>
+// CHECK-DAG: %[[V8:.+]] = vector.transpose %[[V7]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
+// CHECK-DAG: vector.transfer_write %[[V8]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
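The last two tests exercise the stride-1 case, where per-output-position unrolling is unnecessary: every kernel tap reads a contiguous OW-wide span, so the vectorizer extracts a single full-width slice per tap at offset kw * dilation and accumulates whole slices. For the NWC variant with KW = 2, dilation = 2, OW = 2 this is the pattern the checks encode (a sketch with invented names):

%tap0 = vector.extract_strided_slice %in {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
%tap1 = vector.extract_strided_slice %in {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
%acc0 = arith.addf %tap0, %out : vector<4x2x3xf32>   // tap at kw = 0
%acc1 = arith.addf %tap1, %acc0 : vector<4x2x3xf32>  // tap at kw = 1, offset 1 * 2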