diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -904,10 +904,83 @@
   };
 };
 
+/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
+/// dimensions into 1-D depthwise convolution ops.
+struct DownscaleDepthwiseConv2DNhwcHwcOp final
+    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
+    if (linalgOp.hasBufferSemantics())
+      return failure(); // To be implemented
+
+    Value input = convOp.inputs().front();
+    Value kernel = convOp.inputs().back();
+    Value output = convOp.outputs().front();
+
+    auto inputType = input.getType().dyn_cast<RankedTensorType>();
+    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+    auto kernelShape = kernelType.getShape();
+    auto outputShape = outputType.getShape();
+
+    // Only handle the case where at least one of the window dimensions is
+    // of size 1. Other cases can rely on tiling to reduce to such cases.
+    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+    int64_t ohSize = outputShape[1], owSize = outputShape[2];
+    bool removeH = (khSize == 1 && ohSize == 1);
+    bool removeW = (kwSize == 1 && owSize == 1);
+    if (!removeH && !removeW)
+      return failure();
+
+    // Get new shapes and types for all operands by removing the size-1
+    // dimension.
+    using RTTBuilder = RankedTensorType::Builder;
+    auto newInputType = RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+    auto newKernelType = RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+    auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+
+    // Rank-reduce operands.
+    Location loc = convOp.getLoc();
+    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, input, newInputType);
+    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, kernel, newKernelType);
+    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, output, newOutputType);
+
+    // Rank-reduce strides and dilations too.
+    // TODO: dropDim 1-liner helper.
+    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+    strides.erase(strides.begin() + (removeH ? 0 : 1));
+    auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+    auto dilations =
+        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+
+    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
+        loc, newOutputType, ValueRange{newInput, newKernel},
+        ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+    // Insert back.
+    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+        rewriter, loc, conv1DOp.getResult(0), output);
+    rewriter.replaceOp(convOp, inserted);
+
+    return success();
+  };
+};
+
 } // namespace
 
 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit) {
-  patterns.add<DownscaleSizeOneWindowed2DConvolution>(patterns.getContext(),
-                                                      benefit);
+  patterns.add<DownscaleSizeOneWindowed2DConvolution,
+               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
+                                                  benefit);
 }
diff --git a/mlir/test/Dialect/Linalg/decompose-convolution.mlir b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
--- a/mlir/test/Dialect/Linalg/decompose-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
@@ -68,3 +68,27 @@
       outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
   return %0 : tensor<4x1x2x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_tensor
+func @depthwise_conv_2d_nhwc_hwc_tensor(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>, %out: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32> {
+  // CHECK: linalg.depthwise_conv_1d_nwc_wc
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+         ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
+         outs(%out: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
+  return %0: tensor<1x1x56x96xf32>
+}
+
+// -----
+
+// Do not convert convolution ops whose window dimensions are not ones.
+
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_tensor
+func @depthwise_conv_2d_nhwc_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>, %out: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> {
+  // CHECK: linalg.depthwise_conv_2d_nhwc_hwc
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+         ins(%input, %filter: tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
+         outs(%out: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
+  return %0: tensor<1x56x56x96xf32>
+}