diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -702,6 +702,11 @@
     ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
 
+    if (operandTy == resultTy) {
+      rewriter.replaceOp(reshape, args[0]);
+      return success();
+    }
+
     if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
       return failure();
 
@@ -1086,6 +1091,77 @@
         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
           nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
         });
+    return success();
+  }
+};
+
+// This converter translates a tile operation to a reshape, broadcast,
+// reshape. The first reshape minimally expands each tiled dimension to
+// include a preceding size-1 dim. This dim is then broadcast to the
+// appropriate multiple.
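+// For example (mirroring the tests below), tiling a tensor<2x3xi8> by
+// multiples = [2, 1] reshapes to tensor<1x2x1x3xi8>, broadcasts to
+// tensor<2x2x1x3xi8>, and collapses the result to tensor<4x3xi8>.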
+struct TileConverter : public OpConversionPattern<tosa::TileOp> {
+  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::TileOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto input = op.input1();
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto inputShape = inputTy.getShape();
+    auto resultTy = op.getType().cast<ShapedType>();
+    auto elementTy = inputTy.getElementType();
+    int64_t rank = inputTy.getRank();
+
+    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+
+    SmallVector<int64_t> multiples;
+    getValuesFromIntArrayAttribute(op.multiples(), multiples);
+
+    llvm::SmallVector<int64_t> reshapeShape;
+    reshapeShape.reserve(rank * 2);
+    for (int i = 0; i < rank; i++) {
+      reshapeShape.push_back(1);
+      reshapeShape.push_back(inputShape[i]);
+    }
+
+    ShapedType reshapeTy = RankedTensorType::get(reshapeShape, elementTy);
+    Value reshape = rewriter.create<tosa::ReshapeOp>(
+        loc, reshapeTy, input, rewriter.getI64ArrayAttr(reshapeTy.getShape()));
+
+    // Broadcast the newly added dimensions to their appropriate multiple.
+    SmallVector<int64_t> genericShape;
+    for (int i = 0; i < rank; i++) {
+      genericShape.push_back(multiples[i]);
+      genericShape.push_back(inputShape[i]);
+    }
+
+    auto initTensor = rewriter.create<linalg::InitTensorOp>(
+        op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);
+
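+    // Note: the input map reads the reshaped operand with its size-1 dims
+    // sent to constant 0 (e.g. (d0, d1, d2, d3) -> (0, d1, 0, d3) in the
+    // tests below), so each element is repeated across its multiple; the
+    // output map is the identity.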
+    SmallVector<AffineMap, 2> affineMaps = {
+        createAffineMapForType(reshapeTy, rewriter),
+        rewriter.getMultiDimIdentityMap(genericShape.size())};
+
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, RankedTensorType::get(genericShape, elementTy), reshape,
+        ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(genericShape.size()),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
+        });
+
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, resultTy, genericOp.getResult(0),
+        rewriter.getI64ArrayAttr(resultTy.getShape()));
     return success();
   }
 };
@@ -1119,6 +1195,6 @@
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
       ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
-      RescaleConverter, ReverseConverter, TransposeConverter, MatMulConverter,
-      FullyConnectedConverter>(patterns->getContext());
+      RescaleConverter, ReverseConverter, TileConverter, TransposeConverter,
+      MatMulConverter, FullyConnectedConverter>(patterns->getContext());
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -636,6 +636,40 @@
   // CHECK: ^bb0(%arg1: i32, %arg2: i32):
   // CHECK:   linalg.yield %arg1 : i32
   %1 = "tosa.reverse"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
+// CHECK: #[[$MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK: #[[$MAP5:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+
+// CHECK-LABEL: @tile
+func @tile(%arg0 : tensor<2x3xi8>) -> () {
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 2, 1, 3]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>)
+  // CHECK:   linalg.yield %arg1 : i8
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP0]], #[[$MAP1]]]
+  %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<2x3xi8>) -> (tensor<4x3xi8>)
+
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2, 2, 3]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
+  // CHECK:   linalg.yield %arg1 : i8
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP4]], #[[$MAP5]]]
+  %1 = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<2x3xi8>) -> (tensor<2x6xi8>)
+
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2, 7, 3]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
+  // CHECK:   linalg.yield %arg1 : i8
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP4]], #[[$MAP5]]]
+  %2 = "tosa.tile"(%arg0) {multiples = [5, 7]} : (tensor<2x3xi8>) -> (tensor<10x21xi8>)
   return
 }