diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -16,8 +16,11 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include <numeric>
+
 using namespace mlir;
 
 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
@@ -339,6 +342,106 @@
   }
 };
 
+class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    typename tosa::ReshapeOp::Adaptor operands(args);
+
+    ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+
+    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+
+    // Compute the reassociation maps for the linalg operation.
+    ArrayRef<int64_t> expandedShape =
+        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
+                                                  : resultTy.getShape());
+    ArrayRef<int64_t> collapsedShape =
+        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
+                                                  : operandTy.getShape());
+    unsigned currSrcDim = 0, currDstDim = 0;
+    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
+        collapsedShape.size());
+
+    // First scan all dimensions in the source shapes to see whether we have a
+    // perfect case where consecutive dimensions in source are collapsed. For
+    // such case we can just generate one single linalg.reshape.
+    bool isCollapsingSource = true;
+    while (currSrcDim < expandedShape.size() &&
+           currDstDim < collapsedShape.size()) {
+      int64_t dstSize = collapsedShape[currDstDim];
+      int64_t srcSize = expandedShape[currSrcDim];
+      while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
+        reassociationMap[currDstDim].push_back(
+            rewriter.getAffineDimExpr(currSrcDim++));
+        srcSize *= expandedShape[currSrcDim];
+      }
+      if (srcSize == dstSize) {
+        reassociationMap[currDstDim].push_back(
+            rewriter.getAffineDimExpr(currSrcDim++));
+        // If the next dim in collapsedShape is not 1, treat subsequent dims in
+        // expandedShape which are 1 to be collapsed.
+        if (currDstDim == collapsedShape.size() - 1 ||
+            collapsedShape[currDstDim + 1] != 1) {
+          while (currSrcDim < expandedShape.size() &&
+                 expandedShape[currSrcDim] == 1) {
+            reassociationMap[currDstDim].push_back(
+                rewriter.getAffineDimExpr(currSrcDim++));
+          }
+        }
+      } else {
+        isCollapsingSource = false;
+        break;
+      }
+      currDstDim++;
+    }
+    if (currSrcDim != expandedShape.size() ||
+        currDstDim != collapsedShape.size())
+      isCollapsingSource = false;
+
+    // Otherwise, we need to first reduce all source dimensions into one and
+    // then expand to the destination dimensions.
+    if (!isCollapsingSource) {
+      auto getIdentityExprs = [&rewriter](int n) {
+        SmallVector<AffineExpr, 4> exprs;
+        for (int i = 0; i < n; ++i)
+          exprs.push_back(rewriter.getAffineDimExpr(i));
+        return exprs;
+      };
+      Location loc = reshape.getLoc();
+      int64_t totalElems =
+          std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
+                          std::multiplies<int64_t>());
+      auto elemTy = operandTy.getElementType();
+      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
+          // Use operandTy here because we need to collapse all operands
+          // dimensions.
+          getIdentityExprs(operandTy.getShape().size())};
+      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
+          // Use resultTy here because we need to expand to all result
+          // dimensions.
+          getIdentityExprs(resultTy.getShape().size())};
+
+      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
+      Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
+          loc, collapsedTy, args[0], collapsingMap);
+      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+          reshape, resultTy, collapsedOp, expandingMap);
+
+      return success();
+    }
+
+    rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+        reshape, resultTy, args[0], reassociationMap);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -358,6 +461,6 @@
       PointwiseConverter<tosa::LogOp>, PointwiseConverter<tosa::ExpOp>,
       PointwiseConverter<tosa::AbsOp>, PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::PowOp>,
-      PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>>(
-      context);
+      PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
+      ReshapeOpConverter>(context);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -258,3 +258,49 @@
   return
 }
 
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_reshape_downrank
+func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [6]} : (tensor<2x3xf32>) -> tensor<6xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<6xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_reshape_uprank
+func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<6xf32>) -> tensor<2x3xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_reshape_samerank
+func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
+  // CHECK: [[RESHAPE1:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
+  // CHECK: [[RESHAPE2:%.+]] = linalg.tensor_reshape [[RESHAPE1]] [#[[$MAP0]]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<3x2xf32>) -> tensor<2x3xf32>
+  // CHECK: return [[RESHAPE2]]
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+
+// CHECK-LABEL: @test_reshape_downrank_6D
+func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
+  // CHECK: linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [6, 5, 77]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+  return %0 : tensor<6x5x77xf32>
+}
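
Note on the reassociation logic (an addendum to the patch, not part of it): the grouping loop in ReshapeOpConverter can be sanity-checked outside of MLIR. The standalone C++ sketch below mirrors that loop using plain dimension indices in place of AffineDimExprs; the hard-coded shapes, the main() driver, and the printed output are illustrative assumptions only. For the 1x2x3x5x7x11 -> 6x5x77 shapes from @test_reshape_downrank_6D it prints the groups { 0 1 2 } { 3 } { 4 5 }, matching #[[$MAP0]] through #[[$MAP2]] in that test; substituting the 3x2 -> 2x3 shapes from @test_reshape_samerank makes srcSize overshoot dstSize (6 > 3), so isCollapsingSource flips to false and the pattern falls back to the collapse-to-1D-then-expand pair of reshapes instead.

// Standalone sketch of the reassociation computation in ReshapeOpConverter,
// using plain dimension indices instead of AffineDimExprs. Shapes and the
// printed output are illustrative only.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Source (higher-rank) and target (lower-rank) shapes, as in
  // @test_reshape_downrank_6D. Assumes both describe the same number of
  // elements, which tosa.reshape requires of its operand and result.
  std::vector<int64_t> expandedShape = {1, 2, 3, 5, 7, 11};
  std::vector<int64_t> collapsedShape = {6, 5, 77};

  size_t currSrcDim = 0, currDstDim = 0;
  std::vector<std::vector<size_t>> reassociationMap(collapsedShape.size());
  bool isCollapsingSource = true;

  while (currSrcDim < expandedShape.size() &&
         currDstDim < collapsedShape.size()) {
    int64_t dstSize = collapsedShape[currDstDim];
    int64_t srcSize = expandedShape[currSrcDim];
    // Grow the current group until its product reaches the target dim.
    while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
      reassociationMap[currDstDim].push_back(currSrcDim++);
      srcSize *= expandedShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(currSrcDim++);
      // Fold trailing unit dims into this group unless the next target
      // dim is itself 1 and will want them.
      if (currDstDim == collapsedShape.size() - 1 ||
          collapsedShape[currDstDim + 1] != 1) {
        while (currSrcDim < expandedShape.size() &&
               expandedShape[currSrcDim] == 1)
          reassociationMap[currDstDim].push_back(currSrcDim++);
      }
    } else {
      // The product overshot the target dim (e.g. 3x2 -> 2x3): no single
      // grouping of consecutive source dims works, so the lowering must
      // collapse everything to 1-D and re-expand.
      isCollapsingSource = false;
      break;
    }
    currDstDim++;
  }
  if (currSrcDim != expandedShape.size() ||
      currDstDim != collapsedShape.size())
    isCollapsingSource = false;

  // Prints "{ 0 1 2 } { 3 } { 4 5 } single reshape" for the shapes above.
  for (const std::vector<size_t> &group : reassociationMap) {
    std::cout << "{ ";
    for (size_t d : group)
      std::cout << d << ' ';
    std::cout << "} ";
  }
  std::cout << (isCollapsingSource ? "single reshape" : "two reshapes") << '\n';
  return 0;
}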