diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -517,6 +517,11 @@
     ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
 
+    // Identity reshape: forward the operand directly instead of lowering.
+    if (operandTy == resultTy) {
+      rewriter.replaceOp(reshape, args[0]);
+      return success();
+    }
+
     if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
       return failure();
@@ -856,6 +861,71 @@
   }
 };
 
+// This converter translates a tile operation to a reshape, broadcast, reshape.
+// The first reshape minimally expands each tiled dimension to include a
+// preceding size-1 dim. This dim is then broadcasted to the appropriate
+// multiple.
+struct TileConverter : public OpConversionPattern<tosa::TileOp> {
+  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::TileOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto input = op.input1();
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto inputShape = inputTy.getShape();
+    auto resultTy = op.getType().cast<ShapedType>();
+    auto elementTy = inputTy.getElementType();
+    int64_t rank = inputTy.getRank();
+
+    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+
+    SmallVector<int64_t> multiples;
+    getValuesFromIntArrayAttribute(op.multiples(), multiples);
+
+    // Expand each dim (d0, d1, ...) to (1, d0, 1, d1, ...) so every tiled
+    // dimension has a unit dim to broadcast into.
+    llvm::SmallVector<int64_t> reshapeShape;
+    reshapeShape.reserve(rank * 2);
+    for (int i = 0; i < rank; i++) {
+      reshapeShape.push_back(1);
+      reshapeShape.push_back(inputShape[i]);
+    }
+
+    ShapedType reshapeTy = RankedTensorType::get(reshapeShape, elementTy);
+    Value reshape = rewriter.create<tosa::ReshapeOp>(
+        loc, reshapeTy, input, rewriter.getI64ArrayAttr(reshapeTy.getShape()));
+
+    // Determine how far to broadcast the newly expanded dimensions.
+    SmallVector<int64_t> genericShape;
+    for (int i = 0; i < rank; i++) {
+      genericShape.push_back(multiples[i]);
+      genericShape.push_back(inputShape[i]);
+    }
+
+    auto initTensor = rewriter.create<linalg::InitTensorOp>(
+        op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);
+
+    SmallVector<AffineMap, 2> affineMaps = {
+        createAffineMapForType(reshapeTy, rewriter),
+        rewriter.getMultiDimIdentityMap(genericShape.size())};
+
+    // Broadcast via a parallel generic that yields the input element.
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, RankedTensorType::get(genericShape, elementTy), reshape,
+        ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(genericShape.size()),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
+        });
+
+    // Collapse (m0, d0, m1, d1, ...) back to the tiled result shape.
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, resultTy, genericOp.getResult(0),
+        rewriter.getI64ArrayAttr(resultTy.getShape()));
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -880,5 +950,6 @@
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
       ReduceConverter<tosa::ReduceProdOp>, ConcatOpConversion,
-      ReshapeOpConverter, TransposeConverter, RescaleOpConverter>(context);
+      ReshapeOpConverter, TileConverter, TransposeConverter,
+      RescaleOpConverter>(context);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -524,3 +524,38 @@
   %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>)
   return %0 : tensor<1xi8>
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
+// CHECK: #[[$MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK: #[[$MAP5:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+
+// CHECK-LABEL: @tile
+func @tile(%arg0 : tensor<2x3xi8>) -> () {
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 2, 1, 3]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>)
+  // CHECK: linalg.yield %arg1 : i8
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP0]], #[[$MAP1]]]
+  %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<2x3xi8>) -> (tensor<4x3xi8>)
+
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2, 2, 3]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
+  // CHECK: linalg.yield %arg1 : i8
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP4]], #[[$MAP5]]]
+  %1 = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<2x3xi8>) -> (tensor<2x6xi8>)
+
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2, 7, 3]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
+  // CHECK: linalg.yield %arg1 : i8
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP4]], #[[$MAP5]]]
+  %2 = "tosa.tile"(%arg0) {multiples = [5, 7]} : (tensor<2x3xi8>) -> (tensor<10x21xi8>)
+
+  return
+}