diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -96,6 +96,58 @@
     results.push_back(*depthwise);
     return DiagnosedSilenceableFailure::success();
   }
+  // Decompose 2-D pooling ops into their 1-D counterparts.
+  FailureOr<LinalgOp> poolingNhwcSum = tryApply<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>>(
+      target);
+  if (succeeded(poolingNhwcSum)) {
+    results.push_back(*poolingNhwcSum);
+    return DiagnosedSilenceableFailure::success();
+  }
+  FailureOr<LinalgOp> poolingNchwSum = tryApply<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>>(
+      target);
+  if (succeeded(poolingNchwSum)) {
+    results.push_back(*poolingNchwSum);
+    return DiagnosedSilenceableFailure::success();
+  }
+  FailureOr<LinalgOp> poolingNhwcMax = tryApply<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>>(
+      target);
+  if (succeeded(poolingNhwcMax)) {
+    results.push_back(*poolingNhwcMax);
+    return DiagnosedSilenceableFailure::success();
+  }
+  FailureOr<LinalgOp> poolingNhwcMaxUnsigned =
+      tryApply<DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
+                                                     PoolingNwcMaxUnsignedOp>>(
+          target);
+  if (succeeded(poolingNhwcMaxUnsigned)) {
+    results.push_back(*poolingNhwcMaxUnsigned);
+    return DiagnosedSilenceableFailure::success();
+  }
+  FailureOr<LinalgOp> poolingNhwcMin = tryApply<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>>(
+      target);
+  if (succeeded(poolingNhwcMin)) {
+    results.push_back(*poolingNhwcMin);
+    return DiagnosedSilenceableFailure::success();
+  }
+  FailureOr<LinalgOp> poolingNhwcMinUnsigned =
+      tryApply<DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
+                                                     PoolingNwcMinUnsignedOp>>(
+          target);
+  if (succeeded(poolingNhwcMinUnsigned)) {
+    results.push_back(*poolingNhwcMinUnsigned);
+    return DiagnosedSilenceableFailure::success();
+  }
+  FailureOr<LinalgOp> poolingNchwMax = tryApply<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
+      target);
+  if (succeeded(poolingNchwMax)) {
+    results.push_back(*poolingNchwMax);
+    return DiagnosedSilenceableFailure::success();
+  }
   results.assign(1, nullptr);
   return emitDefaultSilenceableFailure(target);
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -628,8 +628,52 @@
         ohIndex = 2;
         owIndex = 3;
       })
+      // Pooling ops take a 2-D window operand, so kh/kw are always 0/1;
+      // oh/ow follow the output layout (NHWC: 1/2, NCHW: 2/3).
+      .Case([&](linalg::PoolingNhwcSumOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::PoolingNchwSumOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 2;
+        owIndex = 3;
+      })
+      .Case([&](linalg::PoolingNhwcMaxOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::PoolingNhwcMinOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::PoolingNchwMaxOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 2;
+        owIndex = 3;
+      })
       .Default([&](Operation *op) {
-        llvm_unreachable("unexpected conv2d operation.");
+        llvm_unreachable("unexpected conv2d/pool2d operation.");
       });
 
   // Only handle the case where at least one of the window dimensions is
@@ -688,6 +732,20 @@
                                                               Conv1DNwcWcfOp>;
 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
                                                               Conv1DNcwFcwOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
+                                                              PoolingNwcSumOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
+                                                              PoolingNcwSumOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
+                                                              PoolingNwcMaxOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<
+    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
+                                                              PoolingNwcMinOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<
+    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
+                                                              PoolingNcwMaxOp>;
 
 FailureOr<DepthwiseConv1DNwcWcOp>
 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
@@ -765,4 +823,15 @@
                                                      Conv1DNcwFcwOp>,
                DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
                                                   benefit);
+  patterns.add<
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
+                                            PoolingNwcMaxUnsignedOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
+                                            PoolingNwcMinUnsignedOp>,
+      DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
+      patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -56,6 +56,132 @@
   return %0: tensor<1x1x56x96xf32>
 }
 
+// CHECK-LABEL: @pooling_nhwc_sum
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_sum(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_sum
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nchw_sum
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_sum
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x?x1x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_max
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_max_unsigned
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                         strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_min
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_min_unsigned
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                         strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nchw_max
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_max
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x?x1x?xf32>
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match interface{LinalgOp} in %arg1