diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -526,12 +526,18 @@ ShapedType inputType = input.getType().cast(); ShapedType weightType = weight.getType().cast(); ShapedType resultType = op.output().getType().cast(); + Type inputEType = inputType.getElementType(); if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && resultType.hasStaticShape())) { return failure(); } + // Quantization still needs to be applied; skip this optimization. + if (op.quantization_info() || !inputEType.isa()) { + return failure(); + } + // Stride must be 1 for this optimization. for (Attribute stride : op.stride().getValue()) { if (!stride.cast().getValue().isOne()) { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -128,6 +128,15 @@ // ----- +// CHECK-LABEL: @depthwise_conv2d_as_mul_q +func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> { + // CHECK: "tosa.depthwise_conv2d" + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> + return %0 : tensor<4x10x10x6xi32> +} + +// ----- + // CHECK-LABEL: @depthwise_conv2d_stride_2 func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { // CHECK: "tosa.depthwise_conv2d"