diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -585,7 +585,7 @@
   }
 };
 
-class ReshapeOpConverter : public OpConversionPattern {
+class ReshapeConverter : public OpConversionPattern {
 public:
   using OpConversionPattern::OpConversionPattern;
 
@@ -727,7 +727,7 @@
   }
 };
 
-class RescaleOpConverter : public OpRewritePattern {
+class RescaleConverter : public OpRewritePattern {
 public:
   using OpRewritePattern::OpRewritePattern;
 
@@ -889,7 +889,7 @@
   }
 };
 
-struct ConcatOpConversion : public OpConversionPattern {
+struct ConcatConverter : public OpConversionPattern {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -936,6 +936,64 @@
   }
 };
 
+/// Lowers tosa.reverse to a linalg.generic whose input indexing map mirrors
+/// the index along the reversed axis (d -> size - 1 - d) while the output is
+/// written through the identity map.
+class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
+public:
+  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReverseOp op,
+                                PatternRewriter &rewriter) const final {
+    auto loc = op.getLoc();
+    Value input = op.input();
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto resultTy = op.getType().cast<ShapedType>();
+    auto rank = resultTy.getRank();
+    auto axis = op.axis();
+
+    // The mirrored index is built from a constant dimension size, so the
+    // input shape has to be fully static.
+    if (!inputTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "reverse lowering requires a statically shaped input");
+
+    // This only creates the destination tensor handed to the generic op as
+    // its output operand; nothing is filled here.
+    auto initTensor = rewriter
+                          .create<linalg::InitTensorOp>(
+                              loc, ArrayRef<Value>({}), inputTy.getShape(),
+                              inputTy.getElementType())
+                          .result();
+
+    // Start from the identity and flip only the reversed axis.
+    SmallVector<AffineExpr, 2> inputExprs;
+    inputExprs.resize(rank);
+
+    for (int i = 0; i < rank; i++)
+      inputExprs[i] = rewriter.getAffineDimExpr(i);
+
+    inputExprs[axis] =
+        rewriter.getAffineConstantExpr(inputTy.getDimSize(axis) - 1) -
+        inputExprs[axis];
+
+    SmallVector<AffineMap, 2> affineMaps = {
+        AffineMap::get(rank, /*symbolCount=*/0, inputExprs,
+                       rewriter.getContext()),
+        rewriter.getMultiDimIdentityMap(rank)};
+
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, resultTy, input, ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(rank),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          // Pure data movement: yield the input element unchanged.
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, *args.begin());
+        });
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -963,6 +1021,6 @@
       IdentityNConverter,
       IdentityNConverter, ReduceConverter,
       ReduceConverter, ReduceConverter,
-      ReduceConverter, ConcatOpConversion,
-      ReshapeOpConverter, TransposeConverter, RescaleOpConverter>(context);
+      ReduceConverter, ConcatConverter, ReshapeConverter,
+      RescaleConverter, ReverseConverter, TransposeConverter>(context);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -598,3 +598,26 @@
   %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>)
   return %0 : tensor<1xi8>
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (-d0 + 4, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 3)>
+
+// CHECK-LABEL: @reverse
+func @reverse(%arg0: tensor<5x4xi32>) -> () {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) {
+  // CHECK: ^bb0(%arg1: i32, %arg2: i32):
+  // CHECK:   linalg.yield %arg1 : i32
+  %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
+
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) {
+  // CHECK: ^bb0(%arg1: i32, %arg2: i32):
+  // CHECK:   linalg.yield %arg1 : i32
+  %1 = "tosa.reverse"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
+
+  return
+}