diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -856,6 +856,56 @@ } }; +struct TileConverter : public OpConversionPattern<tosa::TileOp> { + using OpConversionPattern<tosa::TileOp>::OpConversionPattern; + + // Lowers tosa.tile to one linalg.generic over the result shape. Output + // index d_i reads input index d_i mod inputShape[i], so the input repeats + // along every dimension until the result is filled. + LogicalResult + matchAndRewrite(tosa::TileOp op, ArrayRef<Value> args, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto input = op.input1(); + auto inputTy = input.getType().cast<ShapedType>(); + auto resultTy = op.getType().cast<ShapedType>(); + + // getShape()/getRank() require a ranked type: reject non-static shapes first. + if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape()) + return failure(); + + ArrayRef<int64_t> inputShape = inputTy.getShape(); + int64_t rank = inputTy.getRank(); + // Tiling preserves rank; the indexing maps below rely on it. + if (resultTy.getRank() != rank) + return failure(); + + SmallVector<AffineExpr, 4> inputExprs; + inputExprs.reserve(rank); + for (int64_t i = 0; i < rank; i++) + inputExprs.push_back(rewriter.getAffineDimExpr(i) % + rewriter.getAffineConstantExpr(inputShape[i])); + + auto initTensor = rewriter.create<linalg::InitTensorOp>( + loc, ArrayRef<Value>({}), resultTy.getShape(), + resultTy.getElementType()); + + SmallVector<AffineMap, 2> affineMaps = { + AffineMap::get(rank, /*symbolCount=*/0, inputExprs, + rewriter.getContext()), + rewriter.getMultiDimIdentityMap(rank)}; + + rewriter.replaceOpWithNewOp<linalg::GenericOp>( + op, resultTy, input, ValueRange{initTensor}, affineMaps, + getNParallelLoopsAttrs(rank), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + nestedBuilder.create<linalg::YieldOp>(nestedLoc, blockArgs.front()); + }); + return success(); + } +}; + } // namespace void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns( @@ -880,5 +930,6 @@ IdentityNConverter, ReduceConverter, ReduceConverter, ReduceConverter, ReduceConverter, ConcatOpConversion, - ReshapeOpConverter, TransposeConverter, RescaleOpConverter>(context); + ReshapeOpConverter, TileConverter, TransposeConverter, + RescaleOpConverter>(context); } diff --git
a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -524,3 +524,27 @@ %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>) return %0 : tensor<1xi8> } + + // ----- + + // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 mod 2, d1 mod 3)> + // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> + + // CHECK-LABEL: @tile + func @tile(%arg0 : tensor<2x3xi8>) -> () { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [4, 3] + // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<4x3xi8>) + // CHECK: ^bb0(%[[IN:[a-zA-Z0-9_]*]]: i8 + // CHECK: linalg.yield %[[IN]] : i8 + %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<2x3xi8>) -> (tensor<4x3xi8>) + + // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 6] + // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<2x6xi8>) + %1 = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<2x3xi8>) -> (tensor<2x6xi8>) + + // CHECK: [[INIT:%.+]] = linalg.init_tensor [10, 21] + // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<10x21xi8>) + %2 = "tosa.tile"(%arg0) {multiples = [5, 7]} : (tensor<2x3xi8>) -> (tensor<10x21xi8>) + + return +}