diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -154,9 +154,50 @@
   }
 };
 
+// Canonicalize a tosa.reshape of a compile-time constant into a new
+// tosa.const that already carries the reshaped type (no runtime reshape).
+struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input1();
+    ArrayAttr new_shape = op.new_shape();
+
+    // Check if input is constant
+    DenseElementsAttr inputAttr;
+    if (!matchPattern(input, m_Constant(&inputAttr))) {
+      return failure();
+    }
+
+    // Check if has >1 consumer and is not splat
+    if (!input.hasOneUse() && !inputAttr.isSplat()) {
+      return failure();
+    }
+
+    // Grab the new shape
+    SmallVector<int64_t> new_shape_values = llvm::to_vector<6>(
+        llvm::map_range(new_shape.getValue(), [](const Attribute &val) {
+          return val.cast<IntegerAttr>().getValue().getSExtValue();
+        }));
+
+    // Build new const op with correct output shape
+    ShapedType inputShape = input.getType().cast<ShapedType>();
+    DenseElementsAttr outputAttr =
+        inputAttr.reshape(inputShape.clone(new_shape_values));
+    ConstOp outputOp = rewriter.create<ConstOp>(
+        op.getLoc(), outputAttr.getType(), outputAttr);
+
+    // Replace op with new const op result
+    rewriter.replaceOp(op, outputOp.output());
+    return success();
+  }
+};
+
 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
   results.insert<ReshapeReshapeOptimization>(context);
+  results.insert<ReshapeConstOptimization>(context);
 }
 
 struct ConstantTransposeOptimization
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -174,6 +174,39 @@
 
 // -----
 
+// CHECK-LABEL: @reshape_canonicalize_const
+func @reshape_canonicalize_const() -> tensor<1x10xi32> {
+  // CHECK: %[[VAR0:.+]] = "tosa.const"() {value = dense<0> : tensor<1x10xi32>}
+  // CHECK: return %[[VAR0]]
+  %0 = "tosa.const"() {value = dense<0> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.reshape"(%0) {new_shape = [1, 10]} : (tensor<10xi32>) -> tensor<1x10xi32>
+  return %1 : tensor<1x10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_canonicalize_const_spat
+func @reshape_canonicalize_const_spat() -> (tensor<10xi32>, tensor<1x10xi32>) {
+  // CHECK-DAG: %[[VAR0:.+]] = "tosa.const"() {value = dense<0> : tensor<10xi32>}
+  // CHECK-DAG: %[[VAR1:.+]] = "tosa.const"() {value = dense<0> : tensor<1x10xi32>}
+  // CHECK: return %[[VAR0]], %[[VAR1]]
+  %0 = "tosa.const"() {value = dense<0> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.reshape"(%0) {new_shape = [1, 10]} : (tensor<10xi32>) -> tensor<1x10xi32>
+  return %0 , %1 : tensor<10xi32>, tensor<1x10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_canonicalize_const_sparse
+func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32>) {
+  //CHECK: "tosa.reshape"
+  %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : ()-> tensor<3xi32>
+  %1 = "tosa.reshape"(%0) {new_shape = [1, 3]} : (tensor<3xi32>) -> tensor<1x3xi32>
+  return %0 , %1 : tensor<3xi32>, tensor<1x3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @slice_fold
 func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0