Restored unified diff.

The reshape-of-constant canonicalization now reshapes the constant's value
using the attribute's own type and gives the replacement tosa.const the type
`resultTy`; two canonicalization tests are added.

The text this patch was recovered from had each hunk collapsed onto a single
line and had lost every "<...>" group that begins with a letter. Line breaks
below follow the line counts in the hunk headers. Indentation of context
lines is reconstructed, so the patch may need
`git apply --ignore-whitespace`.

NOTE(review): the parameters of `!quant.uniform` in the two new tests could
not be recovered and are left bare. The value attributes are i8, so the
storage type is presumably i8 -- fill the parameters in and confirm before
running the tests.

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -105,11 +105,9 @@
                                          "Used more than once or not-splat");
 
     // Build new const op with correct output shape
-    ShapedType inputShape = input.getType().cast<ShapedType>();
-    DenseElementsAttr outputAttr =
-        inputAttr.reshape(inputShape.clone(op.getNewShape()));
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputAttr.getType(),
-                                               outputAttr);
+    DenseElementsAttr outputAttr = inputAttr.reshape(
+        inputAttr.getType().cast<ShapedType>().clone(op.getNewShape()));
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultTy, outputAttr);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -358,12 +358,29 @@
 
 // CHECK-LABEL: @reshape_canonicalize_const_sparse
 func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32>) {
-  //CHECK: "tosa.reshape"
+  // CHECK: "tosa.reshape"
   %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : ()-> tensor<3xi32>
   %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 3>} : (tensor<3xi32>) -> tensor<1x3xi32>
   return %0 , %1 : tensor<3xi32>, tensor<1x3xi32>
 }
 
+// CHECK-LABEL: @reshape_canonicalize_quant
+func.func @reshape_canonicalize_quant() -> (tensor<1x3x!quant.uniform>) {
+  // CHECK{literal}: "tosa.const"() {value = dense<[[1, 2, 3]]> : tensor<1x3xi8>} : () -> tensor<1x3x!quant.uniform>
+  %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi8>} : ()-> tensor<3x!quant.uniform>
+  %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 3>} : (tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform>
+  return %1 : tensor<1x3x!quant.uniform>
+}
+
+// CHECK-LABEL: @transpose_canonicalize_strip_quant
+func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3xi8>) {
+  // CHECK: "tosa.const"() {value = dense<0> : tensor<2x1x3xi8>} : () -> tensor<2x1x3xi8>
+  %perms = "tosa.const"() {value = dense<[1, 0, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %0 = "tosa.const"() {value = dense<0> : tensor<1x2x3xi8>} : ()-> tensor<1x2x3x!quant.uniform>
+  %1 = "tosa.transpose"(%0, %perms) : (tensor<1x2x3x!quant.uniform>, tensor<3xi32>) -> tensor<2x1x3xi8>
+  return %1 : tensor<2x1x3xi8>
+}
+
 // CHECK-LABEL: @slice_fold
 func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0