diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -46,7 +46,15 @@
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
     ArrayRef<int64_t> tileSizes);
 
-/// Populates patterns for vectorizing convolution ops.
+/// Populates patterns to decompose high-D convolution ops into low-D ones.
+/// This is a step in progressive lowering for convolution ops; afterwards we
+/// can vectorize the low-D convolution ops.
+void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+
+/// Populates patterns for vectorizing low-D convolution ops. This is a step in
+/// progressive lowering for convolution ops; it assumes high-D convolution ops
+/// were decomposed previously.
 void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -780,3 +780,98 @@
   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
   return success();
 }
+
+namespace {
+// The following are patterns for downscaling convolution ops with size-1
+// window dimensions.
+//
+// Note that we'd eventually want to write such transformations in a generic
+// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
+// and then turning back to named ops. But for now it's fine to have a few
+// patterns matching special ops to get started.
+
+/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
+/// convolution ops.
+struct DownscaleSizeOneWindowed2DConvolution final
+    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
+    if (linalgOp.hasBufferSemantics())
+      return failure(); // To be implemented
+
+    Value input = convOp.inputs().front();
+    Value filter = convOp.inputs().back();
+    Value output = convOp.outputs().front();
+
+    auto inputType = input.getType().dyn_cast<RankedTensorType>();
+    auto filterType = filter.getType().dyn_cast<RankedTensorType>();
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+    auto inputShape = inputType.getShape();
+    auto filterShape = filterType.getShape();
+    auto outputShape = outputType.getShape();
+
+    // Only handle the case where at least one of the window dimensions is
+    // of size 1. Other cases can rely on tiling to reduce to such cases.
+    int64_t fhSize = filterShape[0], fwSize = filterShape[1];
+    int64_t ohSize = outputShape[1], owSize = outputShape[2];
+    if (!(fhSize == 1 && ohSize == 1) && !(fwSize == 1 && owSize == 1))
+      return failure();
+    bool removeH = ohSize == 1;
+
+    // Get new shapes and types for all operands by removing the size-1
+    // dimension.
+
+    SmallVector<int64_t, 3> newInputShape{
+        inputShape[0], inputShape[removeH ? 2 : 1], inputShape[3]};
+    auto newInputType = RankedTensorType::get(
+        newInputShape, inputType.getElementType(), inputType.getEncoding());
+
+    SmallVector<int64_t, 3> newFilterShape{filterShape[removeH ? 1 : 0],
+                                           filterShape[2], filterShape[3]};
+    auto newFilterType = RankedTensorType::get(
+        newFilterShape, filterType.getElementType(), filterType.getEncoding());
+
+    SmallVector<int64_t, 3> newOutputShape{
+        outputShape[0], outputShape[removeH ? 2 : 1], outputShape[3]};
+    auto newOutputType = RankedTensorType::get(
+        newOutputShape, outputType.getElementType(), outputType.getEncoding());
+
+    SmallVector<ReassociationIndices, 3> ioReshapeIndices = {{0}, {1, 2}, {3}};
+    SmallVector<ReassociationIndices, 3> fReshapeIndices = {{0, 1}, {2}, {3}};
+
+    // Reshape all operands for 1-D convolution.
+    Location loc = convOp.getLoc();
+    Value newInput = rewriter.create<linalg::TensorCollapseShapeOp>(
+        loc, newInputType, input, ioReshapeIndices);
+    Value newFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
+        loc, newFilterType, filter, fReshapeIndices);
+    Value newOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
+        loc, newOutputType, output, ioReshapeIndices);
+
+    // We need to shrink the strides and dilations too.
+    auto stride = convOp.strides().getFlatValue<int64_t>(removeH ? 1 : 0);
+    auto stridesAttr = rewriter.getI64VectorAttr(stride);
+    auto dilation = convOp.dilations().getFlatValue<int64_t>(removeH ? 1 : 0);
+    auto dilationsAttr = rewriter.getI64VectorAttr(dilation);
+
+    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
+        loc, newOutputType, ValueRange{newInput, newFilter},
+        ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+    rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
+        convOp, outputType, conv1DOp.getResult(0), ioReshapeIndices);
+    return success();
+  }
+};
+
+} // namespace
+
+void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
+                                                  PatternBenefit benefit) {
+  patterns.add<DownscaleSizeOneWindowed2DConvolution>(patterns.getContext(),
+                                                      benefit);
+}
diff --git a/mlir/test/Dialect/Linalg/decompose-convolution.mlir b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-decompose-convolution-patterns %s | FileCheck %s
+
+// CHECK-LABEL: func @conv2d_nhwc_4x1x2x8_tensor
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<4x1x6x3xf32>, %[[FILTER:.+]]: tensor<1x2x3x8xf32>, %[[INIT:.+]]: tensor<4x1x2x8xf32>)
+func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x1x6x3xf32>, %filter: tensor<1x2x3x8xf32>, %init: tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[3, 2]> : tensor<2xi64>}
+    ins(%input, %filter : tensor<4x1x6x3xf32>, tensor<1x2x3x8xf32>)
+    outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
+  return %0 : tensor<4x1x2x8xf32>
+}
+
+// CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x1x6x3xf32> into tensor<4x6x3xf32>
+// CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2], [3]] : tensor<1x2x3x8xf32> into tensor<2x3x8xf32>
+// CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x1x2x8xf32> into tensor<4x2x8xf32>
+// CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
+// CHECK-SAME: dilations = dense<3> : vector<1xi64>
+// CHECK-SAME: strides = dense<2> : vector<1xi64>
+// CHECK-SAME: ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<4x6x3xf32>, tensor<2x3x8xf32>)
+// CHECK-SAME: outs(%[[INIT_1D]] : tensor<4x2x8xf32>)
+// CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x2x8xf32> into tensor<4x1x2x8xf32>
+// CHECK: return %[[CONV_2D]]
+
+// -----
+
+// CHECK-LABEL: func @conv2d_nhwc_qxqx1xq_tensor
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x1x?xf32>, %[[FILTER:.+]]: tensor<?x1x?x?xf32>, %[[INIT:.+]]: tensor<?x?x1x?xf32>)
+func @conv2d_nhwc_qxqx1xq_tensor(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x1x?x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[3, 2]> : tensor<2xi64>}
+    ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<?x1x?x?xf32>)
+    outs(%init : tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+  return %0 : tensor<?x?x1x?xf32>
+}
+
+// CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
+// CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2], [3]] : tensor<?x1x?x?xf32> into tensor<?x?x?xf32>
+// CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
+// CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
+// CHECK-SAME: dilations = dense<2> : vector<1xi64>
+// CHECK-SAME: strides = dense<3> : vector<1xi64>
+// CHECK-SAME: ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SAME: outs(%[[INIT_1D]] : tensor<?x?x?xf32>)
+// CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
+// CHECK: return %[[CONV_2D]]
+
+// -----
+
+// Do not convert convolution ops whose window dimensions are not ones.
+
+// CHECK-LABEL: func @conv2d_nhwc_4x1x2x8_tensor
+func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x3x5x3xf32>, %filter: tensor<2x2x3x8xf32>, %init: tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32> {
+  // CHECK: linalg.conv_2d_nhwc_hwcf
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %filter : tensor<4x3x5x3xf32>, tensor<2x2x3x8xf32>)
+    outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
+  return %0 : tensor<4x1x2x8xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -141,6 +141,11 @@
       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                      "tiled_loop"),
       llvm::cl::init("for")};
+  Option<bool> testDecomposeConvolutionPattern{
+      *this, "test-decompose-convolution-patterns",
+      llvm::cl::desc("Test a set of patterns to rewrite high-D convolution ops "
+                     "into low-D ones"),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -565,6 +570,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyDecomposeConvolutionPatterns(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateDecomposeConvolutionPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
@@ -776,6 +787,8 @@
   }
   if (testInterchangePattern.hasValue())
     return applyInterchangePattern(getFunction(), testInterchangePattern);
+  if (testDecomposeConvolutionPattern)
+    return applyDecomposeConvolutionPatterns(getFunction());
 }
 
 namespace mlir {
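
For context on how these pieces compose: the header comments describe a progressive lowering where high-D convolutions are first decomposed and the resulting low-D convolutions are then vectorized. The sketch below shows one way a client pass could chain the two populate functions; the wrapper `lowerConvolutions` is hypothetical and not part of this patch, while populateDecomposeConvolutionPatterns, populateConvolutionVectorizationPatterns, and applyPatternsAndFoldGreedily are the APIs exercised by the patch.

// Hypothetical usage sketch: decompose high-D convolutions into low-D ones,
// then vectorize the remaining low-D convolution ops. Only the populate*
// calls and applyPatternsAndFoldGreedily come from this patch; the wrapper
// function is illustrative.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult lowerConvolutions(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();

  // Step 1: rewrite, e.g., linalg.conv_2d_nhwc_hwcf with a size-1 window
  // dimension into linalg.conv_1d_nwc_wcf.
  {
    RewritePatternSet patterns(ctx);
    linalg::populateDecomposeConvolutionPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      return failure();
  }

  // Step 2: vectorize the low-D convolution ops produced by step 1.
  {
    RewritePatternSet patterns(ctx);
    linalg::populateConvolutionVectorizationPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      return failure();
  }
  return success();
}

Running the two steps as separate greedy applications keeps decomposition independent of vectorization, mirroring how the test driver in this patch applies the decomposition patterns on their own.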