diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -118,8 +118,6 @@
   let builders = [Tosa_ConvOpQuantInfoBuilder];
 
   let verifier = [{ return verifyConvOp(*this); }];
-
-  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -423,100 +423,6 @@
   results.insert(context);
 }
 
-struct Conv2DFullyConnectedOptimization
-    : public OpRewritePattern<tosa::Conv2DOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value weight = op.weight();
-    ShapedType inputType = input.getType().cast<ShapedType>();
-    ShapedType weightType = weight.getType().cast<ShapedType>();
-
-    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
-      return failure();
-    }
-
-    // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
-        return failure();
-      }
-    }
-
-    // Only works for a 1x1 kernel.
-    ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[1] != 1 || weightShape[2] != 1) {
-      return failure();
-    }
-
-    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t> revisedInputShape{
-        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
-    auto revisedInputShapeType = RankedTensorType::get(
-        revisedInputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getI64ArrayAttr(revisedInputShape))
-                             .getResult();
-
-    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
-    llvm::SmallVector<int64_t> revisedWeightShape{weightShape[0],
-                                                  weightShape[3]};
-    auto revisedWeightShapeType = RankedTensorType::get(
-        revisedWeightShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getI64ArrayAttr(revisedWeightShape))
-                              .getResult();
-
-    // Perform a fully connected network over the reshaped input and weight.
-    llvm::SmallVector<int64_t> fullyConnectedShape{
-        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
-    auto fullyConnectedShapeType = RankedTensorType::get(
-        fullyConnectedShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-
-    Value fullyConnectedValue;
-    if (op.quantization_info()) {
-      fullyConnectedValue =
-          rewriter
-              .create<tosa::FullyConnectedOp>(
-                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
-                  reshapedWeight, op.bias(), op.quantization_info().getValue())
-              .getResult();
-    } else {
-      fullyConnectedValue = rewriter
-                                .create<tosa::FullyConnectedOp>(
-                                    op.getLoc(), fullyConnectedShapeType,
-                                    reshapedInput, reshapedWeight, op.bias())
-                                .getResult();
-    }
-
-    // Reshape output to [N, IH, IW, OC].
-    llvm::SmallVector<int64_t> outputShape{inputShape[0], inputShape[1],
-                                           inputShape[2], weightShape[0]};
-    auto outputShapeType = RankedTensorType::get(
-        outputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, outputShapeType, fullyConnectedValue,
-        rewriter.getI64ArrayAttr(outputShape));
-    return success();
-  }
-};
-
-void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                           MLIRContext *context) {
-  results.insert<Conv2DFullyConnectedOptimization>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -66,48 +66,12 @@
   return %0 : tensor<?x?xf32>
 }
 
-// -----
-
-// CHECK-LABEL: @conv2d_as_fully_connected
-func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
-  // CHECK-NOT: "tosa.conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
-  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
-  // CHECK: return %[[VAR3]]
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
-  return %0 : tensor<4x10x10x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_stride_2
-func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
-  // CHECK: "tosa.conv2d"
-  %weight = "tosa.const"() {value = dense<[[[[1.0, 1.0]]], [[[1.0, 1.0]]], [[[1.0, 1.0]]]]> : tensor<3x1x1x2xf32>} : ()-> tensor<3x1x1x2xf32>
-  %bias = "tosa.const"() {value = dense<0.0> : tensor<3xf32>} : ()-> tensor<3xf32>
-  %0 = "tosa.conv2d"(%arg0, %weight, %bias) {pad = [0, 0, 0, 0], stride = [2, 2], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
-  return %0 : tensor<4x10x10x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_weight_2x2
-func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
-  // CHECK: "tosa.conv2d"
-  %weight = "tosa.const"() {value = dense<[[[[1.0], [1.0]], [[1.0], [1.0]]]]> : tensor<1x2x2x1xf32>} : ()-> tensor<1x2x2x1xf32>
-  %bias = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : ()-> tensor<1xf32>
-  %0 = "tosa.conv2d"(%arg0, %weight, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32>
-  return %0 : tensor<4x10x10x1xf32>
-}
-
 // -----
 
 // CHECK-LABEL: @pad_noop
 func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: return %arg0
-  %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
   %1 = "tosa.pad"(%arg0, %0) : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
@@ -118,7 +82,7 @@
 func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
   // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
   %1 = "tosa.pad"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }
@@ -129,7 +93,7 @@
 func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
   // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
   %1 = "tosa.pad"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
@@ -140,7 +104,7 @@
 func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<42> : tensor<i32>}
   // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
   %1 = "tosa.pad"(%arg0, %arg1) { quantization_info = {input_zp = 42:i32} } : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }
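Note: the pattern deleted above rewrote a stride-1, 1x1-kernel tosa.conv2d into a
tosa.fully_connected bracketed by two reshapes, collapsing the [N,IH,IW,IC] input
to a [N*IH*IW, IC] matrix. A minimal sketch of that rewrite, using the shapes from
the removed @conv2d_as_fully_connected test (4*10*10 = 400 rows; the %input,
%weight, and %bias names are placeholders, not values from the test):

  // Collapse [4,10,10,2] input to a [400,2] matrix and [3,1,1,2] weights to [3,2].
  %0 = "tosa.reshape"(%input) {new_shape = [400, 2]}
         : (tensor<4x10x10x2xf32>) -> tensor<400x2xf32>
  %1 = "tosa.reshape"(%weight) {new_shape = [3, 2]}
         : (tensor<3x1x1x2xf32>) -> tensor<3x2xf32>
  // The 1x1 convolution is now a matrix product with 3 output channels.
  %2 = "tosa.fully_connected"(%0, %1, %bias)
         : (tensor<400x2xf32>, tensor<3x2xf32>, tensor<3xf32>) -> tensor<400x3xf32>
  // Restore the NHWC layout, [N, IH, IW, OC].
  %3 = "tosa.reshape"(%2) {new_shape = [4, 10, 10, 3]}
         : (tensor<400x3xf32>) -> tensor<4x10x10x3xf32>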