diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -423,8 +423,7 @@ results.insert(context); } -struct Conv2DFullyConnectedOptimization - : public OpRewritePattern { +struct Conv2DIsFullyConnected : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::Conv2DOp op, @@ -439,6 +438,12 @@ return failure(); } + for (Attribute pad : op.pad().getValue()) { + if (!pad.cast().getValue().isZero()) { + return failure(); + } + } + // Stride must be 1 for this optimization. for (Attribute stride : op.stride().getValue()) { if (!stride.cast().getValue().isOne()) { @@ -456,9 +461,8 @@ ArrayRef inputShape = inputType.getShape(); llvm::SmallVector revisedInputShape{ inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]}; - auto revisedInputShapeType = RankedTensorType::get( - revisedInputShape, - input.getType().dyn_cast().getElementType()); + auto revisedInputShapeType = + RankedTensorType::get(revisedInputShape, inputType.getElementType()); auto reshapedInput = rewriter .create( op.getLoc(), revisedInputShapeType, input, @@ -468,9 +472,8 @@ // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC]. llvm::SmallVector revisedWeightShape{weightShape[0], weightShape[3]}; - auto revisedWeightShapeType = RankedTensorType::get( - revisedWeightShape, - weight.getType().dyn_cast().getElementType()); + auto revisedWeightShapeType = + RankedTensorType::get(revisedWeightShape, weightType.getElementType()); auto reshapedWeight = rewriter .create( op.getLoc(), revisedWeightShapeType, weight, @@ -480,9 +483,8 @@ // Perform a fully connected network over the reshaped input and weight. llvm::SmallVector fullyConnectedShape{ inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]}; - auto fullyConnectedShapeType = RankedTensorType::get( - fullyConnectedShape, - resultType.dyn_cast().getElementType()); + auto fullyConnectedShapeType = + RankedTensorType::get(fullyConnectedShape, resultType.getElementType()); Value fullyConnectedValue; if (op.quantization_info()) { @@ -512,7 +514,7 @@ void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } struct DepthwiseConv2DMulOptimization diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -105,6 +105,15 @@ // ----- +// CHECK-LABEL: @conv2d_padded +func @conv2d_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x12x12x3xf32> { + // CHECK: "tosa.conv2d" + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x12x12x3xf32> + return %0 : tensor<4x12x12x3xf32> +} + +// ----- + // CHECK-LABEL: @conv2d_stride_2 func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> { // CHECK: "tosa.conv2d"