diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1414,6 +1414,8 @@
     No data conversion happens during a reshape operation.
   }];
 
+  let hasCanonicalizer = 1;
+
   let arguments = (ins
     Tosa_Tensor:$input1,
     I64ArrayAttr:$new_shape
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -638,7 +638,9 @@
 
     if (newShape.size() != rank) {
+      // tosa.reshape requires the target shape as its new_shape attribute.
       operand = rewriter.create<tosa::ReshapeOp>(
-          loc, RankedTensorType::get(newShape, type.getElementType()), operand);
+          loc, RankedTensorType::get(newShape, type.getElementType()), operand,
+          rewriter.getI64ArrayAttr(newShape));
     }
 
     operands.push_back(operand);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -101,6 +102,48 @@
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator Canonicalizers.
+//===----------------------------------------------------------------------===//
+
+/// Removes a tosa.reshape whose result type equals its input type (a no-op).
+struct RemoveReshapeNoop : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.input1().getType() != op.getType())
+      return failure();
+
+    rewriter.replaceOp(op, op.input1());
+    return success();
+  }
+};
+
+/// Folds reshape(reshape(x)) into one reshape of x that keeps the outer op's
+/// result type and new_shape.
+struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto producer = op.input1().getDefiningOp<tosa::ReshapeOp>();
+    if (!producer)
+      return failure();
+
+    // No data conversion happens during a reshape, so the outer reshape can
+    // consume the inner reshape's input directly.
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, op.getType(), producer.input1(), op.new_shape());
+    return success();
+  }
+};
+
+void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert<RemoveReshapeNoop, ReshapeReshapeOptimization>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -143,7 +143,8 @@ SmallVector reshapeOutputShape; - computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape); + computeReshapeOutput(outputType.getShape(), lowerRankShape, + reshapeOutputShape); auto reshapeInputType = lowerTensorValue.getType().cast(); auto reshapeOutputType = RankedTensorType::get( diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir --- a/mlir/test/Dialect/Tosa/broadcast.mlir +++ b/mlir/test/Dialect/Tosa/broadcast.mlir @@ -136,6 +136,15 @@ return %0 : tensor<14x15xf32> } +// ----- +// CHECK-LABEL: broadcast19 +func @broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) { + // CHECK: reshape + // CHECK: sub + %0 = "tosa.sub"(%arg0, %arg1) : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32> + return %0 : tensor<64x64x17xf32> +} + // ----- // CHECK-LABEL: broadcast_mul func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {