diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3045,6 +3045,119 @@
     return success();
   }
 };
+
+static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
+  return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
+}
+
+LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input,
+                                           Value kernel, Value iZp, Value kZp,
+                                           Value init, Attribute stride,
+                                           Attribute dilation,
+                                           PatternRewriter &rewriter) {
+  Location loc = operation->getLoc();
+  auto linalgOp = dyn_cast<LinalgOp>(operation);
+  // Exit out on the memref version of this operation.
+  if (!linalgOp || !linalgOp.hasTensorSemantics())
+    return failure();
+
+  auto result = operation->getResult(0);
+
+  auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
+  auto initTy = init.getType().dyn_cast<RankedTensorType>();
+  auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
+  if (!kernelTy || !initTy || !resultTy)
+    return failure();
+
+  if (kernelTy.getDimSize(3) != 1)
+    return failure();
+
+  // Collapse kernel dims.
+  SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
+      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
+  auto newKernelTy = RankedTensorType::get(
+      {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
+      kernelTy.getElementType());
+  auto collapsedKernel = rewriter.create<linalg::TensorCollapseShapeOp>(
+      loc, newKernelTy, kernel, collapsedKernelDims);
+
+  // Collapse init dims.
+  SmallVector<ReassociationIndices, 4> collapsedInitDims = {
+      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
+      getIndicesVector(3, 5)};
+  auto newInitTy =
+      RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
+                             initTy.getDimSize(2), initTy.getDimSize(3)},
+                            initTy.getElementType());
+  auto collapsedInit = rewriter.create<linalg::TensorCollapseShapeOp>(
+      loc, newInitTy, init, collapsedInitDims);
+
+  Value newConv;
+  if (isa<DepthwiseConv2DNhwcOp>(operation)) {
+    newConv = rewriter
+                  .create<DepthwiseConv2DNhwOp>(
+                      loc, newInitTy, ValueRange{input, collapsedKernel},
+                      ValueRange{collapsedInit}, stride, dilation)
+                  .getResult(0);
+  } else if (isa<DepthwiseConv2DNhwcQOp>(operation)) {
+    newConv =
+        rewriter
+            .create<DepthwiseConv2DNhwQOp>(
+                loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
+                ValueRange{collapsedInit}, stride, dilation)
+            .getResult(0);
+  }
+
+  if (!newConv)
+    return failure();
+
+  // Expand dimensions back out to the original result type.
+  rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
+      operation, resultTy, newConv, collapsedInitDims);
+  return success();
+}
+
+struct SimplifyDepthwiseConvOp
+    : public OpRewritePattern<DepthwiseConv2DNhwcOp> {
+  using OpRewritePattern<DepthwiseConv2DNhwcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *operation = op.getOperation();
+    Value input = op.getInputOperand(0)->get();
+    Value kernel = op.getInputOperand(1)->get();
+    Value init = op.getOutputOperand(0)->get();
+
+    auto stride = op.strides();
+    auto dilation = op.dilations();
+
+    return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
+                                        nullptr, init, stride, dilation,
+                                        rewriter);
+  }
+};
+
+struct SimplifyDepthwiseConvQOp
+    : public OpRewritePattern<DepthwiseConv2DNhwcQOp> {
+  using OpRewritePattern<DepthwiseConv2DNhwcQOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcQOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *operation = op.getOperation();
+    Value input = op.getInputOperand(0)->get();
+    Value kernel = op.getInputOperand(1)->get();
+    Value iZp = op.getInputOperand(2)->get();
+    Value kZp = op.getInputOperand(3)->get();
+    Value init = op.getOutputOperand(0)->get();
+
+    auto stride = op.strides();
+    auto dilation = op.dilations();
+
+    return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
+                                        init, stride, dilation, rewriter);
+  }
+};
+
 } // namespace
 
 #define LINALGOP_FOLDERS(XXX)                                                  \
@@ -3070,5 +3183,6 @@
 
 void LinalgDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
-  results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
+  results.add<EraseDeadLinalgOp, FoldTensorCastOp, SimplifyDepthwiseConvOp,
+              SimplifyDepthwiseConvQOp>(getContext());
 }
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1004,3 +1004,27 @@
   return %r2 : index
 }
 
+// -----
+
+// CHECK-LABEL: @depthwise_conv
+func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
+  // CHECK-DAG: %[[KERNEL:.+]] = linalg.tensor_collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
+  // CHECK-DAG: %[[INIT:.+]] = linalg.tensor_collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
+  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv2D_nhw {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
+  // CHECK: %[[OUT:.+]] = linalg.tensor_expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
+  %0 = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
+  return %0 : tensor<?x?x?x?x1xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_q
+func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
+  // CHECK-DAG: %[[KERNEL:.+]] = linalg.tensor_collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
+  // CHECK-DAG: %[[INIT:.+]] = linalg.tensor_collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
+  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv2D_nhw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
+  // CHECK: %[[OUT:.+]] = linalg.tensor_expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
+  %0 = linalg.depthwise_conv2D_nhwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
+  return %0 : tensor<?x?x?x?x1xi32>
+}