diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -82,40 +82,9 @@
   }
 };
 
-struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput1();
-    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
-    ShapedType resultTy = llvm::cast<ShapedType>(op.getType());
-
-    if (inputTy.getElementType() != resultTy.getElementType())
-      return rewriter.notifyMatchFailure(op, "element type does not match.");
-
-    // Check if input is constant
-    DenseElementsAttr inputAttr;
-    if (!matchPattern(input, m_Constant(&inputAttr)))
-      return rewriter.notifyMatchFailure(op, "Non-constant input.");
-
-    // Check if has >1 consumer and is not splat
-    if (!input.hasOneUse() && !inputAttr.isSplat())
-      return rewriter.notifyMatchFailure(op,
-                                         "Used more than once or not-splat");
-
-    // Build new const op with correct output shape
-    DenseElementsAttr outputAttr = inputAttr.reshape(
-        llvm::cast<ShapedType>(inputAttr.getType()).clone(op.getNewShape()));
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultTy, outputAttr);
-    return success();
-  }
-};
-
 void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results.add<ReshapeReshapeOptimization>(context);
-  results.add<ReshapeConstOptimization>(context);
 }
 
 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
@@ -851,12 +820,25 @@
   if (inputTy == outputTy)
     return getInput1();
 
+  // Constants must have static shape.
+  if (!outputTy.hasStaticShape())
+    return {};
+
   auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
+  if (!operand)
+    return {};
+
+  // Okay to duplicate splat constants.
+  if (operand.isSplat()) {
     return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
   }
 
-  return {};
+  // Don't duplicate other constants.
+  if (!getInput1().hasOneUse())
+    return {};
+
+  return operand.reshape(
+      llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
 }
 
 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -375,16 +375,24 @@
 }
 
 // CHECK-LABEL: @reshape_canonicalize_const
-func.func @reshape_canonicalize_const() -> tensor<1x10xi32> {
-  // CHECK: %[[VAR0:.+]] = "tosa.const"() <{value = dense<0> : tensor<1x10xi32>}
+func.func @reshape_canonicalize_const() -> tensor<1x5xi32> {
+  // CHECK: %[[VAR0:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 1, 2, 3, 4]]> : tensor<1x5xi32>}
   // CHECK: return %[[VAR0]]
-  %0 = "tosa.const"() {value = dense<0> : tensor<10xi32>} : () -> tensor<10xi32>
-  %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 10>} : (tensor<10xi32>) -> tensor<1x10xi32>
-  return %1 : tensor<1x10xi32>
+  %0 = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
+  %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 5>} : (tensor<5xi32>) -> tensor<1x5xi32>
+  return %1 : tensor<1x5xi32>
+}
+
+// CHECK-LABEL: @reshape_canonicalize_const_dynamic
+func.func @reshape_canonicalize_const_dynamic() -> tensor<1x?xi32> {
+  // CHECK: tosa.reshape
+  %0 = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
+  %1 = "tosa.reshape"(%0) {new_shape = array<i64: 1, -1>} : (tensor<5xi32>) -> tensor<1x?xi32>
+  return %1 : tensor<1x?xi32>
 }
 
-// CHECK-LABEL: @reshape_canonicalize_const_spat
-func.func @reshape_canonicalize_const_spat() -> (tensor<10xi32>, tensor<1x10xi32>) {
+// CHECK-LABEL: @reshape_canonicalize_const_splat
+func.func @reshape_canonicalize_const_splat() -> (tensor<10xi32>, tensor<1x10xi32>) {
   // CHECK-DAG: %[[VAR0:.+]] = "tosa.const"() <{value = dense<0> : tensor<10xi32>}
   // CHECK-DAG: %[[VAR1:.+]] = "tosa.const"() <{value = dense<0> : tensor<1x10xi32>}
   // CHECK: return %[[VAR0]], %[[VAR1]]