diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -16,8 +16,11 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include <numeric>
+
 using namespace mlir;
 
 static SmallVector<StringRef, 4> getNParallelLoopsAttrs(unsigned nParallelLoops) {
@@ -339,6 +342,102 @@
 }
 };
 
+// Lowers tosa.reshape (static shapes only) to linalg.tensor_reshape. When the
+// source dimensions collapse cleanly onto the destination dimensions a single
+// reshape is emitted; otherwise the source is collapsed to 1-D and re-expanded.
+class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    typename tosa::ReshapeOp::Adaptor operands(args);
+
+    ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+
+    // Only static shapes can be expressed as linalg reassociation maps here.
+    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+
+    // Compute the reassociation maps for the linalg operation. srcShape is
+    // always the higher-rank side so the walk below only ever collapses.
+    ArrayRef<int64_t> srcShape =
+        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
+                                                  : resultTy.getShape());
+    ArrayRef<int64_t> dstShape =
+        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
+                                                  : operandTy.getShape());
+    unsigned currSrcDim = 0, currDstDim = 0;
+    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
+        dstShape.size());
+
+    // First scan all dimensions in the source shapes to see whether we have a
+    // perfect case where consecutive dimensions in source are collapsed. For
+    // such case we can just generate one single linalg.reshape.
+    bool isCollapsingSource = true;
+    while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
+      int64_t dstSize = dstShape[currDstDim];
+      int64_t srcSize = srcShape[currSrcDim];
+      while (srcSize < dstSize && currSrcDim < srcShape.size()) {
+        reassociationMap[currDstDim].push_back(
+            rewriter.getAffineDimExpr(currSrcDim++));
+        srcSize *= srcShape[currSrcDim];
+      }
+      if (srcSize == dstSize) {
+        reassociationMap[currDstDim].push_back(
+            rewriter.getAffineDimExpr(currSrcDim++));
+        // If the next dim in dstShape is not 1, treat subsequent dims in
+        // srcShape which are 1 to be collapsed.
+        if (currDstDim == dstShape.size() - 1 ||
+            dstShape[currDstDim + 1] != 1) {
+          while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
+            reassociationMap[currDstDim].push_back(
+                rewriter.getAffineDimExpr(currSrcDim++));
+          }
+        }
+      } else {
+        // Source dims overshot the destination dim: no clean grouping exists.
+        isCollapsingSource = false;
+        break;
+      }
+      currDstDim++;
+    }
+    if (currSrcDim != srcShape.size() || currDstDim != dstShape.size())
+      isCollapsingSource = false;
+
+    // Otherwise, we need to first reduce all source dimensions into one and
+    // then expand to the destination dimensions.
+    if (!isCollapsingSource) {
+      auto getIdentityExprs = [&rewriter](int n) {
+        SmallVector<AffineExpr, 2> exprs;
+        for (int i = 0; i < n; ++i)
+          exprs.push_back(rewriter.getAffineDimExpr(i));
+        return exprs;
+      };
+      Location loc = reshape.getLoc();
+      int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
+                                           std::multiplies<int64_t>());
+      auto elemTy = operandTy.getElementType();
+      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
+          // Use operandTy here because we need to collapse all operands
+          // dimensions.
+          getIdentityExprs(operandTy.getShape().size())};
+      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
+          // Use resultTy here because we need to expand to all result
+          // dimensions.
+          getIdentityExprs(resultTy.getShape().size())};
+
+      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
+      Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
+          loc, collapsedTy, args[0], collapsingMap);
+      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+          reshape, resultTy, collapsedOp, expandingMap);
+
+      return success();
+    }
+
+    rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+        reshape, resultTy, args[0], reassociationMap);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -358,6 +457,6 @@
       PointwiseConverter<tosa::LogicalOrOp>,
       PointwiseConverter<tosa::LogicalXorOp>,
       PointwiseConverter<tosa::SelectOp>,
-      PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>>(
-          context);
+      PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
+      ReshapeOpConverter>(context);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -258,3 +258,38 @@
   return
 }
 
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_reshape_downrank
+func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [6]} : (tensor<2x3xf32>) -> tensor<6xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<6xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_reshape_uprank
+func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<6xf32>) -> tensor<2x3xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_reshape_samerank
+func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
+  // CHECK: [[RESHAPE1:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
+  // CHECK: [[RESHAPE2:%.+]] = linalg.tensor_reshape [[RESHAPE1]] [#[[$MAP0]]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<3x2xf32>) -> tensor<2x3xf32>
+  // CHECK: return [[RESHAPE2]]
+  return %0 : tensor<2x3xf32>
+}
+