diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -105,11 +105,9 @@
                                          "Used more than once or not-splat");
 
     // Build new const op with correct output shape
-    ShapedType inputShape = input.getType().cast<ShapedType>();
-    DenseElementsAttr outputAttr =
-        inputAttr.reshape(inputShape.clone(op.getNewShape()));
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputAttr.getType(),
-                                               outputAttr);
+    DenseElementsAttr outputAttr = inputAttr.reshape(
+        inputAttr.getType().cast<ShapedType>().clone(op.getNewShape()));
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultTy, outputAttr);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -364,6 +364,22 @@
   return %0 , %1 : tensor<3xi32>, tensor<1x3xi32>
 }
 
+// CHECK-LABEL: @reshape_canonicalize_quant
+func.func @reshape_canonicalize_quant() -> (tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>) {
+  // CHECK{LITERAL}: "tosa.const"() {value = dense<[[1, 2, 3]]> : tensor<1x3xi8>} : () -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
+  %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi8>} : ()-> tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>
+  %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 3>} : (tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
+  return %1 : tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
+}
+
+// CHECK-LABEL: @reshape_canonicalize_strip_quant
+func.func @reshape_canonicalize_strip_quant() -> (tensor<1x3xi8>) {
+  // CHECK{LITERAL}: "tosa.const"() {value = dense<[[1, 2, 3]]> : tensor<1x3xi8>} : () -> tensor<1x3xi8>
+  %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi8>} : ()-> tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>
+  %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 3>} : (tensor<3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x3xi8>
+  return %1 : tensor<1x3xi8>
+}
+
 // CHECK-LABEL: @slice_fold
 func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0