diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -807,133 +807,12 @@
     return rewriter.notifyMatchFailure(
         op, "unable to create linalg.generic body for reduce op");
 
-  SmallVector<ReassociationExprs, 4> reassociationMap;
-  uint64_t expandInputRank =
-      linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
-  reassociationMap.resize(expandInputRank);
-
-  for (uint64_t i = 0; i < expandInputRank; i++) {
-    int32_t dimToPush = i > axis ? i + 1 : i;
-    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
-  }
-
-  if (expandInputRank != 0) {
-    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
-    reassociationMap[expandedDim].push_back(
-        rewriter.getAffineDimExpr(expandedDim + 1));
-  }
-
-  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      op, resultTy, linalgOp.getResults()[0], reassociationMap);
+  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+      op, resultTy, linalgOp.getResults()[0],
+      rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
   return success();
 }
 
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
-                                  ArrayRef<int64_t> rhsShape,
-                                  SmallVector<int64_t> &intermediateShape,
-                                  bool isDynamic) {
-  if (isDynamic) {
-    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
-    intermediateShape = {ShapedType::kDynamic};
-    return true;
-  }
-
-  if (lhsShape.empty() || rhsShape.empty()) {
-    intermediateShape = {};
-    return true;
-  }
-
-  unsigned currLhsDim = 0, currRhsDim = 0;
-  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
-    int64_t rhsSize = rhsShape[currRhsDim];
-    int64_t lhsSize = lhsShape[currLhsDim];
-    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
-           currRhsDim < rhsShape.size()) {
-      if (lhsSize < rhsSize) {
-        currLhsDim++;
-        if (currLhsDim < lhsShape.size()) {
-          lhsSize *= lhsShape[currLhsDim];
-        }
-      } else {
-        currRhsDim++;
-        if (currRhsDim < rhsShape.size()) {
-          rhsSize *= rhsShape[currRhsDim];
-        }
-      }
-    }
-    if (lhsSize == rhsSize) {
-      intermediateShape.push_back(lhsSize);
-    }
-    currRhsDim++;
-    currLhsDim++;
-  }
-
-  // If the iterators didn't reach the end and their leftover dimensions are
-  // not equal to 1 an intermediate shape was not found.
-  while (currLhsDim < lhsShape.size()) {
-    if (lhsShape[currLhsDim++] != 1) {
-      return false;
-    }
-  }
-
-  while (currRhsDim < rhsShape.size()) {
-    if (rhsShape[currRhsDim++] != 1) {
-      return false;
-    }
-  }
-
-  return true;
-}
-
-static bool createReassociationMapsForCollapse(
-    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
-    ArrayRef<int64_t> dstShape,
-    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
-
-  // If the shape is dynamic, create a map for collapsing into one dimension.
-  if (isDynamic) {
-    SmallVector<AffineExpr, 2> exprs;
-    for (int i = 0, s = srcShape.size(); i < s; ++i)
-      exprs.push_back(rewriter.getAffineDimExpr(i));
-    reassociationMap = {exprs};
-    return true;
-  }
-
-  if (dstShape.empty()) {
-    reassociationMap = {};
-    return true;
-  }
-
-  reassociationMap.resize(dstShape.size());
-  unsigned currSrcDim = 0, currDstDim = 0;
-  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
-    int64_t dstSize = dstShape[currDstDim];
-    int64_t srcSize = srcShape[currSrcDim];
-    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      srcSize *= srcShape[currSrcDim];
-    }
-    if (srcSize == dstSize) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      // If the next dim in collapsedShape is not 1, treat subsequent dims in
-      // expandedShape which are 1 to be collapsed.
-      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
-        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
-          reassociationMap[currDstDim].push_back(
-              rewriter.getAffineDimExpr(currSrcDim++));
-        }
-      }
-    }
-    currDstDim++;
-  }
-
-  // If both iterators didn't reach the end, we have leftover dimentions which
-  // implies that we have a mismatch in shape.
-  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
-}
-
 namespace {
 
 template <typename SrcOp>
@@ -947,115 +826,6 @@
   }
 };
 
-class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
-public:
-  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
-    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    if (isDynamic && resultTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot collapse dynamic dims to more than one dimension");
-    }
-
-    SmallVector<ReassociationExprs, 4> reassociationMap;
-    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
-                                            resultTy.getShape(),
-                                            reassociationMap, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape,
-          "tosa.reshape Attempting to collapse into an incompatible shape");
-    }
-
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                               intermediateShape, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot collapse into given shape");
-    }
-
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    return success();
-  }
-};
-
-class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
-public:
-  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
-    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    if (isDynamic && operandTy.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          reshape, "Cannot expand dynamic dims from more than one dimension");
-    }
-
-    SmallVector<ReassociationExprs, 4> reassociationMap;
-    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
-                                            operandTy.getShape(),
-                                            reassociationMap, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape,
-          "tosa.reshape Attempting to expand into an incompatible shape");
shape"); - } - - SmallVector intermediateShape; - if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(), - intermediateShape, isDynamic) || - intermediateShape != operandTy.getShape()) { - return rewriter.notifyMatchFailure( - reshape, "tosa.reshape Cannot expand into given shape"); - } - rewriter.replaceOpWithNewOp( - reshape, resultTy, adaptor.getOperands()[0], reassociationMap); - return success(); - } -}; - -class ReshapeConverterCollapseExpand - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - ShapedType operandTy = adaptor.getInput1().getType().cast(); - ShapedType resultTy = reshape.getType().template cast(); - bool isDynamic = !operandTy.hasStaticShape(); - - SmallVector intermediateShape; - if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(), - intermediateShape, isDynamic)) { - return rewriter.notifyMatchFailure( - reshape, "tosa.reshape Cannot identify an intermediate shape between " - "the given two shapes"); - } - - Value collapse = rewriter.create( - reshape.getLoc(), - RankedTensorType::get(intermediateShape, - reshape.getType().getElementType()), - adaptor.getInput1()); - Value expand = - rewriter.create(reshape.getLoc(), resultTy, collapse); - rewriter.replaceOp(reshape, expand); - - return success(); - } -}; - class TransposeConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -2295,13 +2065,6 @@ patterns->add(patterns->getContext(), /*benefit=*/300); - patterns->add(patterns->getContext(), - /*benefit=*/100); - patterns->add(patterns->getContext(), - /*benefit=*/200); - patterns->add(patterns->getContext(), - /*benefit=*/300); - patterns->add< // clang-format off PointwiseConverter, diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -56,6 +56,7 @@ target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); + target.addLegalOp(); target.addLegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -15,21 +15,236 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/DialectConversion.h" using namespace mlir; using namespace tosa; +static bool findIntermediateShape(ArrayRef lhsShape, + ArrayRef rhsShape, + SmallVector &intermediateShape, + bool isDynamic) { + if (isDynamic) { + // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1 + intermediateShape = {ShapedType::kDynamic}; + return true; + } + + if (lhsShape.empty() || rhsShape.empty()) { + intermediateShape = {}; + return true; + } + + unsigned currLhsDim = 0, currRhsDim = 0; + while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) { + int64_t rhsSize = rhsShape[currRhsDim]; + int64_t lhsSize = lhsShape[currLhsDim]; + while (lhsSize != rhsSize && currLhsDim < lhsShape.size() && + currRhsDim < rhsShape.size()) { + if (lhsSize < rhsSize) { + 
+        currLhsDim++;
+        if (currLhsDim < lhsShape.size()) {
+          lhsSize *= lhsShape[currLhsDim];
+        }
+      } else {
+        currRhsDim++;
+        if (currRhsDim < rhsShape.size()) {
+          rhsSize *= rhsShape[currRhsDim];
+        }
+      }
+    }
+    if (lhsSize == rhsSize) {
+      intermediateShape.push_back(lhsSize);
+    }
+    currRhsDim++;
+    currLhsDim++;
+  }
+
+  // If the iterators didn't reach the end and their leftover dimensions are
+  // not equal to 1 an intermediate shape was not found.
+  while (currLhsDim < lhsShape.size()) {
+    if (lhsShape[currLhsDim++] != 1) {
+      return false;
+    }
+  }
+
+  while (currRhsDim < rhsShape.size()) {
+    if (rhsShape[currRhsDim++] != 1) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+static bool createReassociationMapsForCollapse(
+    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
+    ArrayRef<int64_t> dstShape,
+    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
+
+  // If the shape is dynamic, create a map for collapsing into one dimension.
+  if (isDynamic) {
+    SmallVector<AffineExpr, 2> exprs;
+    for (int i = 0, s = srcShape.size(); i < s; ++i)
+      exprs.push_back(rewriter.getAffineDimExpr(i));
+    reassociationMap = {exprs};
+    return true;
+  }
+
+  if (dstShape.empty()) {
+    reassociationMap = {};
+    return true;
+  }
+
+  reassociationMap.resize(dstShape.size());
+  unsigned currSrcDim = 0, currDstDim = 0;
+  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
+    int64_t dstSize = dstShape[currDstDim];
+    int64_t srcSize = srcShape[currSrcDim];
+    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
+      reassociationMap[currDstDim].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      srcSize *= srcShape[currSrcDim];
+    }
+    if (srcSize == dstSize) {
+      reassociationMap[currDstDim].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      // If the next dim in collapsedShape is not 1, treat subsequent dims in
+      // expandedShape which are 1 to be collapsed.
+      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
+        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
+          reassociationMap[currDstDim].push_back(
+              rewriter.getAffineDimExpr(currSrcDim++));
+        }
+      }
+    }
+    currDstDim++;
+  }
+
+  // If both iterators didn't reach the end, we have leftover dimentions which
+  // implies that we have a mismatch in shape.
+  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
+}
+
 namespace {
+class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
-class SliceConverter : public OpRewritePattern<tosa::SliceOp> {
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    if (isDynamic && resultTy.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          reshape, "Cannot collapse dynamic dims to more than one dimension");
+    }
+
+    SmallVector<ReassociationExprs, 4> reassociationMap;
+    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
+                                            resultTy.getShape(),
+                                            reassociationMap, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape,
+          "tosa.reshape Attempting to collapse into an incompatible shape");
+    }
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                               intermediateShape, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot collapse into given shape");
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+    return success();
+  }
+};
+
+class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
 public:
-  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
-                                PatternRewriter &rewriter) const final {
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    if (isDynamic && operandTy.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          reshape, "Cannot expand dynamic dims from more than one dimension");
+    }
+
+    SmallVector<ReassociationExprs, 4> reassociationMap;
+    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
+                                            operandTy.getShape(),
+                                            reassociationMap, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape,
+          "tosa.reshape Attempting to expand into an incompatible shape");
+    }
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                               intermediateShape, isDynamic) ||
+        intermediateShape != operandTy.getShape()) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot expand into given shape");
+    }
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+    return success();
+  }
+};
+
+class ReshapeConverterCollapseExpand
+    : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
+                               intermediateShape, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot identify an intermediate shape between "
+                   "the given two shapes");
+    }
+
+    Value collapse = rewriter.create<tosa::ReshapeOp>(
+        reshape.getLoc(),
+        RankedTensorType::get(intermediateShape,
+                              reshape.getType().getElementType()),
+        adaptor.getInput1());
+    Value expand =
+        rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
+    rewriter.replaceOp(reshape, expand);
+
+    return success();
+  }
+};
+
+class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
+public:
+  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
     Location loc = sliceOp.getLoc();
-    Value input = sliceOp.getInput();
+    Value input = adaptor.getInput();
     SmallVector<Value> strides, sizes;
     ArrayRef<int64_t> starts = sliceOp.getStart();
     strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
@@ -139,4 +354,10 @@
 void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<SliceConverter>(patterns->getContext());
+  patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
+                                          /*benefit=*/100);
+  patterns->add<ReshapeConverterExpand>(patterns->getContext(),
+                                        /*benefit=*/200);
+  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext(),
+                                                /*benefit=*/300);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -96,7 +96,7 @@
 // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<2xf32>
 func.func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]]
+  // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG0]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %[[ARG1]] : tensor<f32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
   // CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
@@ -116,7 +116,7 @@
 // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xf32>
 func.func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG1]]
+  // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG1]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]], [[RESHAPE]] : tensor<2xf32>, tensor<f32>) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
   // CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
@@ -137,8 +137,8 @@
 // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
 func.func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
-  // CHECK: [[RESHAPE1:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
-  // CHECK: [[RESHAPE2:%.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+  // CHECK: [[RESHAPE1:%.+]] = "tosa.reshape"(%[[ARG0]]) {new_shape = array<i64: 3>}
+  // CHECK: [[RESHAPE2:%.+]] = "tosa.reshape"(%[[ARG1]]) {new_shape = array<i64: 2>}
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) {
   // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32,
%[[ARG4:.*]]: f32): // CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32 @@ -536,94 +536,6 @@ return } -// ----- - -// CHECK-LABEL: @test_reshape_downrank -// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]] -func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> { - // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<2x3xf32>) -> tensor<6xf32> - // CHECK: return [[RESHAPE]] - return %0 : tensor<6xf32> -} - -// ----- - -// CHECK-LABEL: @test_reshape_downrank_dyn -// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]] -func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor { - // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<2x?xf32>) -> tensor - // CHECK: return [[RESHAPE]] - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @test_reshape_uprank -// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]] -func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> { - // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<6xf32>) -> tensor<2x3xf32> - // CHECK: return [[RESHAPE]] - return %0 : tensor<2x3xf32> -} - -// ----- - -// CHECK-LABEL: @test_reshape_uprank_dyn -// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]] -func.func @test_reshape_uprank_dyn(%arg0: tensor) -> tensor<2x?xf32> { - // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x?xf32> - // CHECK: return [[RESHAPE]] - return %0 : tensor<2x?xf32> -} - -// ----- - -// CHECK-LABEL: @test_reshape_samerank -// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>) -func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> { - // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] - // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<3x2xf32>) -> tensor<2x3xf32> - // CHECK-NEXT: return %[[RESHAPE2]] - return %0 : tensor<2x3xf32> -} - -// ----- - -// CHECK-LABEL: @test_reshape_samerank_dyn -// CHECK-SAME: (%[[ARG0:.*]]: tensor) -func.func @test_reshape_samerank_dyn(%arg0: tensor) -> tensor<2x?xf32> { - // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] - // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x?xf32> - // CHECK-NEXT: return %[[RESHAPE2]] - return %0 : tensor<2x?xf32> -} - -// ----- - -// CHECK-LABEL: @test_reshape_downrank_6D -// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: -func.func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> { - // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4, 5]] - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> - return %0 : tensor<6x5x77xf32> -} - -// ----- - -// CHECK-LABEL: @test_reshape_downrank_6D_dyn -// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: -func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor { - // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3, 4, 5]] - // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2]] - %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<1x2x?x5x7x11xf32>) -> tensor - return %0 : tensor -} // ----- @@ -714,7 +626,7 @@ // CHECK: ^bb0(%[[ARG1:.*]]: 
f32, %[[ARG2:.*]]: f32) // CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32 // CHECK: linalg.yield [[RES]] : f32 - // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32> + // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array} %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32> // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32> @@ -724,7 +636,7 @@ // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32) // CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32 // CHECK: linalg.yield [[RES]] : f32 - // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32> + // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array} %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32> // CHECK: arith.constant 1.0 @@ -764,7 +676,7 @@ // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32) // CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32 // CHECK: linalg.yield %[[RES]] : f32 - // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor into tensor + // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array} %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor) -> tensor return } @@ -784,7 +696,7 @@ // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32) // CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32 // CHECK: linalg.yield %[[RES]] : f32 - // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}] : tensor into tensor<1xf32> + // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array} : (tensor) -> tensor<1xf32> %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor) -> tensor<1xf32> return } @@ -806,7 +718,7 @@ // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32) // CHECK: %[[RES:.+]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32 // CHECK: linalg.yield %[[RES]] : f32 - // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32> + // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array} %0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32> return } @@ -828,7 +740,7 @@ // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32) // CHECK: %[[MAX:.+]] = arith.maxf %[[ARG1]], %[[ARG2]] : f32 // CHECK: linalg.yield %[[MAX]] : f32 - // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor into tensor + // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array} %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor) -> tensor return } @@ -849,7 +761,7 @@ // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) // CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32 // CHECK: linalg.yield [[RES]] : i32 - // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32> + // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array} %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32> // CHECK: [[INIT:%.+]] = tensor.empty() @@ -859,7 +771,7 @@ // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) // CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32 // CHECK: linalg.yield [[RES]] : i32 - // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32> + // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array} %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32> // CHECK: arith.constant 1 @@ -899,7 +811,7 @@ // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i1, %[[ARG2:[0-9a-zA-Z_]+]]: i1) // CHECK: [[RES:%.+]] = arith.andi %[[ARG1]], 
%[[ARG2]] : i1 // CHECK: linalg.yield [[RES]] : i1 - // CHECK: tensor.expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1> + // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array} %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1> // CHECK: arith.constant false @@ -1231,21 +1143,21 @@ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>) // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8 // CHECK: linalg.yield %[[ARG1]] : i8 - // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1, 2], [3]] + // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array} %0 = "tosa.tile"(%arg0) {multiples = array} : (tensor<2x3xi8>) -> (tensor<4x3xi8>) // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>) // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8 // CHECK: linalg.yield %[[ARG1]] : i8 - // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]] + // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array} %1 = "tosa.tile"(%arg0) {multiples = array} : (tensor<2x3xi8>) -> (tensor<2x6xi8>) // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>) // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8 // CHECK: linalg.yield %[[ARG1]] : i8 - // CHECK: tensor.collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]] + // CHECK: "tosa.reshape"([[GENERIC]]) {new_shape = array} %2 = "tosa.tile"(%arg0) {multiples = array} : (tensor<2x3xi8>) -> (tensor<10x21xi8>) return @@ -1265,8 +1177,7 @@ // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor) outs(%[[INIT]] : tensor<2x?x1x3xi8>) // CHECK: ^bb0(%[[ARG1:.+]]: i8, // CHECK: linalg.yield %[[ARG1]] : i8 - // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]] - // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]] + // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array} %0 = "tosa.tile"(%arg0) {multiples = array} : (tensor) -> (tensor) return @@ -1286,8 +1197,7 @@ // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs(%[[INIT]] : tensor<2x2x?x3xi8>) // CHECK: ^bb0(%[[ARG1:.+]]: i8, // CHECK: linalg.yield %[[ARG1]] : i8 - // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]] - // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]] + // CHECK: "tosa.reshape"(%[[GENERIC]]) {new_shape = array} %0 = "tosa.tile"(%arg0) {multiples = array} : (tensor<2x3xi8>) -> (tensor<2x?xi8>) return diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -1,6 +1,95 @@ // RUN: mlir-opt --split-input-file --tosa-to-tensor %s -o -| FileCheck %s -// CHECK-LABEL: @slice +// CHECK-LABEL: 
@test_reshape_downrank +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]] +func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> { + // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] + %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<2x3xf32>) -> tensor<6xf32> + // CHECK: return [[RESHAPE]] + return %0 : tensor<6xf32> +} + +// ----- + +// CHECK-LABEL: @test_reshape_downrank_dyn +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]] +func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor { + // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] + %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<2x?xf32>) -> tensor + // CHECK: return [[RESHAPE]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @test_reshape_uprank +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]] +func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> { + // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] + %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<6xf32>) -> tensor<2x3xf32> + // CHECK: return [[RESHAPE]] + return %0 : tensor<2x3xf32> +} + +// ----- + +// CHECK-LABEL: @test_reshape_uprank_dyn +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]] +func.func @test_reshape_uprank_dyn(%arg0: tensor) -> tensor<2x?xf32> { + // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] + %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x?xf32> + // CHECK: return [[RESHAPE]] + return %0 : tensor<2x?xf32> +} + +// ----- + +// CHECK-LABEL: @test_reshape_samerank +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>) +func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> { + // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] + %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<3x2xf32>) -> tensor<2x3xf32> + // CHECK-NEXT: return %[[RESHAPE2]] + return %0 : tensor<2x3xf32> +} + +// ----- + +// CHECK-LABEL: @test_reshape_samerank_dyn +// CHECK-SAME: (%[[ARG0:.*]]: tensor) +func.func @test_reshape_samerank_dyn(%arg0: tensor) -> tensor<2x?xf32> { + // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] + // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] + %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x?xf32> + // CHECK-NEXT: return %[[RESHAPE2]] + return %0 : tensor<2x?xf32> +} + +// ----- + +// CHECK-LABEL: @test_reshape_downrank_6D +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: +func.func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> { + // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4, 5]] + %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> + return %0 : tensor<6x5x77xf32> +} + +// ----- + +// CHECK-LABEL: @test_reshape_downrank_6D_dyn +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: +func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor { + // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3, 4, 5]] + // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2]] + %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<1x2x?x5x7x11xf32>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABLE: func @slice func.func @slice(%arg0: tensor<6xf32>) ->() { // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1] %0 = "tosa.slice"(%arg0) {start = array, size = 
array<i64: 1>} : (tensor<6xf32>) -> (tensor<1xf32>)