Index: mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
===================================================================
--- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1041,6 +1041,19 @@
   }
 };
 
+struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> {
+  DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<Conv2DOp>(context, benefit) {}
+
+  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+                                               PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(Conv2DOp convOp,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(convOp, rewriter);
+  }
+};
+
 ///
 /// Linalg generalization pattern.
 ///
Index: mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -266,6 +266,7 @@
   DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
   DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
   DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
+  DOWNSCALE(DownscaleConv2DOp)
 #undef DOWNSCALE_NORMAL
 #undef DOWNSCALE_CALL
 #undef DOWNSCALE
Index: mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1362,14 +1362,71 @@
   return conv1DOp;
 }
 
+FailureOr<Conv1DOp>
+DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
+                                            PatternRewriter &rewriter) const {
+  if (convOp.hasBufferSemantics())
+    return failure(); // To be implemented.
+
+  Value input = convOp.getInputs().front();
+  Value kernel = convOp.getInputs().back();
+  Value output = convOp.getOutputs().front();
+
+  auto inputType = input.getType().dyn_cast<RankedTensorType>();
+  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+  auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+  auto kernelShape = kernelType.getShape();
+  auto outputShape = outputType.getShape();
+
+  // Only handle the case where at least one of the window dimensions is
+  // of size 1. Other cases can rely on tiling to reduce to such cases.
+  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+  int64_t ohSize = outputShape[0], owSize = outputShape[1];
+  bool removeH = (khSize == 1 && ohSize == 1);
+  bool removeW = (kwSize == 1 && owSize == 1);
+  if (!removeH && !removeW)
+    return failure();
+
+  // Get new shapes and types for all operands by removing the size-1
+  // dimension.
+  using RTTBuilder = RankedTensorType::Builder;
+  RankedTensorType newInputType =
+      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
+  RankedTensorType newKernelType =
+      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+  RankedTensorType newOutputType =
+      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
+
+  // Rank-reduce operands.
+  Location loc = convOp.getLoc();
+  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, input, newInputType);
+  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, kernel, newKernelType);
+  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, output, newOutputType);
+
+  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
+                                            ValueRange{newInput, newKernel},
+                                            ValueRange{newOutput});
+
+  // Insert back.
+  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, conv1DOp.getResult(0), output);
+  rewriter.replaceOp(convOp, inserted);
+
+  return conv1DOp;
+}
+
 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit) {
   patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
                                                      Conv1DNwcWcfOp>,
                DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
                                                      Conv1DNcwFcwOp>,
-               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
-                                                  benefit);
+               DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
+      patterns.getContext(), benefit);
   patterns.add<
       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
       DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
Index: mlir/test/Dialect/Linalg/transform-op-decompose.mlir
===================================================================
--- mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -56,6 +56,23 @@
   return %0: tensor<1x1x56x96xf32>
 }
 
+// CHECK-LABEL: @conv_2d
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<1x?xf32>)
+func.func @conv_2d(%input: tensor<1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.conv_2d
+    ins (%input, %filter: tensor<1x?xf32>, tensor<1x?xf32>)
+    outs (%init: tensor<1x?xf32>) -> tensor<1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<1x?xf32>
+}
+
 // CHECK-LABEL: @pooling_nhwc_sum
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>