diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -657,6 +657,65 @@
   }
 };
 
+/// Lowers tosa.reverse to a linalg.generic that copies the input into a new
+/// tensor. Along `axis` the input is read at index (dimSize - 1 - i); every
+/// other dimension uses the identity index.
+class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
+public:
+  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReverseOp op,
+                                PatternRewriter &rewriter) const final {
+    auto loc = op.getLoc();
+    Value input = op.input();
+    auto inputTy = input.getType().template cast<ShapedType>();
+    auto resultTy = op.getType().template cast<ShapedType>();
+    auto rank = resultTy.getRank();
+    auto axis = op.axis();
+
+    // The reversed index is built from a constant dimension size, so only
+    // statically shaped inputs are supported.
+    if (!inputTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "reverse lowering requires a statically shaped input");
+
+    // Guard the inputExprs[axis] / getDimSize(axis) accesses below.
+    if (axis >= static_cast<uint64_t>(rank))
+      return rewriter.notifyMatchFailure(op, "reverse axis is out of range");
+
+    // Destination tensor for the generic op; its contents are never read.
+    auto initTensor = rewriter
+                          .create<linalg::InitTensorOp>(
+                              loc, ArrayRef<Value>({}), inputTy.getShape(),
+                              inputTy.getElementType())
+                          .result();
+
+    SmallVector<AffineExpr, 2> inputExprs;
+    inputExprs.resize(rank);
+
+    for (int i = 0; i < rank; i++)
+      inputExprs[i] = rewriter.getAffineDimExpr(i);
+
+    inputExprs[axis] =
+        rewriter.getAffineConstantExpr(inputTy.getDimSize(axis) - 1) -
+        inputExprs[axis];
+
+    SmallVector<AffineMap, 2> affineMaps = {
+        AffineMap::get(rank, /*symbolCount=*/0, inputExprs,
+                       rewriter.getContext()),
+        rewriter.getMultiDimIdentityMap(rank)};
+
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, resultTy, input, ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(rank),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, *args.begin());
+        });
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -680,6 +739,6 @@
       IdentityNConverter<tosa::IdentityOp>,
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
-      ReduceConverter<tosa::ReduceProdOp>, ReshapeOpConverter,
-      TransposeConverter>(context);
+      ReduceConverter<tosa::ReduceProdOp>, ReshapeOpConverter,
+      ReverseConverter, TransposeConverter>(context);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -433,3 +433,26 @@
   %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
   return
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (-d0 + 4, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 3)>
+
+// CHECK-LABEL: @reverse
+func @reverse(%arg0: tensor<5x4xi32>) -> () {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) {
+  // CHECK: ^bb0(%arg1: i32, %arg2: i32):
+  // CHECK: linalg.yield %arg1 : i32
+  %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
+
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) {
+  // CHECK: ^bb0(%arg1: i32, %arg2: i32):
+  // CHECK: linalg.yield %arg1 : i32
+  %1 = "tosa.reverse"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
+
+  return
+}