diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -187,6 +187,8 @@
   let builders = [Tosa_ConvOpQuantInfoBuilder];
 
   let verifier = [{ return verifyConvOp(*this); }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -515,6 +515,97 @@
   results.insert<Conv2DIsFullyConnected>(context);
 }
 
+// Rewrites a stride-1, 1x1-kernel depthwise convolution as an elementwise
+// multiply (plus the bias add): reshape the input to [N, H, W, C, 1] and the
+// weights to [1, 1, 1, C, M], broadcast-multiply, then reshape the product
+// back to the conv output shape [N, H, W, C * M].
+struct DepthwiseConv2DMulOptimization
+    : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    Value weight = op.weight();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType weightType = weight.getType().cast<ShapedType>();
+    ShapedType resultType = op.output().getType().cast<ShapedType>();
+
+    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
+          resultType.hasStaticShape())) {
+      return failure();
+    }
+
+    // Stride must be 1 for this optimization.
+    for (Attribute stride : op.stride().getValue()) {
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+        return failure();
+      }
+    }
+
+    // Only works for a 1x1 kernel.
+    ArrayRef<int64_t> weightShape = weightType.getShape();
+    if (weightShape[0] != 1 || weightShape[1] != 1) {
+      return failure();
+    }
+
+    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    llvm::SmallVector<int64_t> revisedInputShape{
+        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
+    auto revisedInputShapeType = RankedTensorType::get(
+        revisedInputShape,
+        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedInput = rewriter
+                             .create<tosa::ReshapeOp>(
+                                 op.getLoc(), revisedInputShapeType, input,
+                                 rewriter.getI64ArrayAttr(revisedInputShape))
+                             .getResult();
+
+    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
+    llvm::SmallVector<int64_t> revisedWeightShape{1, 1, 1, weightShape[2],
+                                                  weightShape[3]};
+    auto revisedWeightShapeType = RankedTensorType::get(
+        revisedWeightShape,
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedWeight = rewriter
+                              .create<tosa::ReshapeOp>(
+                                  op.getLoc(), revisedWeightShapeType, weight,
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))
+                              .getResult();
+
+    // Perform an elementwise mul over the reshaped input and weight.
+    llvm::SmallVector<int64_t> mulShape{inputShape[0], inputShape[1],
+                                        inputShape[2], inputShape[3],
+                                        weightShape[3]};
+    auto mulShapeType = RankedTensorType::get(
+        mulShape,
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+    Value mulValue =
+        rewriter
+            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
+                                 reshapedWeight, /*shift=*/0)
+            .getResult();
+
+    // Reshape output to [N, H, W, C * M].
+    auto outputShape = op.output().getType().cast<ShapedType>().getShape();
+    auto outputShapeType = RankedTensorType::get(
+        outputShape,
+        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto outputValue = rewriter.create<tosa::ReshapeOp>(
+        op.getLoc(), outputShapeType, mulValue,
+        rewriter.getI64ArrayAttr(outputShape));
+
+    // Add in the bias.
+    rewriter
+        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
+                                         op.bias())
+        .getResult();
+    return success();
+  }
+};
+
+void DepthwiseConv2DOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<DepthwiseConv2DMulOptimization>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -106,6 +106,44 @@
   return %0 : tensor<4x10x10x1xf32>
 }
 
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_as_mul
+func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
+  // CHECK-NOT: "tosa.depthwise_conv2d"
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
+  // CHECK-SAME: -> tensor<4x10x10x2x1xf32>
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
+  // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
+  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
+  // CHECK-SAME: -> tensor<4x10x10x2x3xf32>
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
+  // CHECK-SAME: -> tensor<4x10x10x6xf32>
+  // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
+  // CHECK-SAME: -> tensor<4x10x10x6xf32>
+  // CHECK: return %[[VAR4]]
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
+  return %0 : tensor<4x10x10x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_stride_2
+func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
+  // CHECK: "tosa.depthwise_conv2d"
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [2, 2], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
+  return %0 : tensor<4x10x10x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_weight_2x2
+func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
+  // CHECK: "tosa.depthwise_conv2d"
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
+  return %0 : tensor<4x10x10x6xf32>
+}
+
 // ----
 
 // CHECK-LABEL: @pad_noop