diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1510,6 +1510,70 @@
   }
 };
 
+// Lowers tosa.gather to a linalg.indexed_generic over the result shape: the
+// gather position is read from the indices operand and the element is fetched
+// from the input with tensor.extract. Only static shapes are supported.
+class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
+public:
+  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::GatherOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto input = args[0];
+    auto indices = args[1];
+
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto indicesTy = indices.getType().cast<ShapedType>();
+    auto resultTy = op.getType().cast<ShapedType>();
+
+    // The init tensor is created from the result's static shape alone, so the
+    // result must be statically shaped in addition to both operands.
+    if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape() ||
+        !resultTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "require input, indices and result to have static shape");
+
+    auto resultElementTy = resultTy.getElementType();
+
+    auto loc = op.getLoc();
+
+    auto initTensor =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
+                                          resultTy.getShape(), resultElementTy)
+            .result();
+
+    // The indices operand is addressed by the two leading loop dimensions; the
+    // output uses the identity map.
+    SmallVector<AffineMap, 2> affineMaps = {
+        AffineMap::get(
+            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
+            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
+            rewriter.getContext()),
+        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
+
+    auto genericOp = rewriter.create<linalg::IndexedGenericOp>(
+        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
+        ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(resultTy.getRank()),
+        [&](OpBuilder &b, Location nestedLoc, ValueRange loopIndices,
+            ValueRange blockArgs) {
+          // blockArgs[0] is the element of the indices operand; it replaces
+          // the middle coordinate when reading from the input.
+          auto indexValue = blockArgs[0];
+          auto index0 = loopIndices[0];
+          Value index1 = rewriter.create<IndexCastOp>(
+              nestedLoc, rewriter.getIndexType(), indexValue);
+          auto index2 = loopIndices[2];
+          Value extract = rewriter.create<tensor::ExtractOp>(
+              nestedLoc, input, ValueRange{index0, index1, index2});
+          rewriter.create<linalg::YieldOp>(nestedLoc, extract);
+        });
+    rewriter.replaceOp(op, genericOp.getResult(0));
+    return success();
+  }
+};
+
 // Lowerings the TableOp to a series of gathers and numerica operations. This
 // includes interpolation between the high/low values. For the I8 varient, this
 // simplifies to a single gather operation.
@@ -1806,6 +1870,7 @@
       TransposeConverter,
       MatMulConverter,
       MaxPool2dConverter,
-      FullyConnectedConverter>(patterns->getContext());
-      // clang-format on
+      FullyConnectedConverter,
+      GatherConverter>(patterns->getContext());
+  // clang-format on
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -833,6 +833,32 @@
 
 // -----
 
+// CHECK-LABEL: @gather_float
+func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
+  // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x3xi32>) outs(%[[INIT]] : tensor<2x3x2xf32>)
+  // CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[ARG0:.+]]: i32, %[[ARG1:.+]]: f32)
+  // CHECK:   %[[CAST:.+]] = index_cast %[[ARG0]]
+  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<2x3x2xf32>
+  // CHECK:   linalg.yield %[[EXTRACT]]
+  %0 = "tosa.gather"(%arg0, %arg1)  : (tensor<2x3x2xf32>, tensor<2x3xi32>)  -> (tensor<2x3x2xf32>)
+  return
+}
+
+// CHECK-LABEL: @gather_int
+func @gather_int(%arg0: tensor<2x3x2xi32>, %arg1: tensor<2x3xi32>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
+  // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x3xi32>) outs(%[[INIT]] : tensor<2x3x2xi32>)
+  // CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+  // CHECK:   %[[CAST:.+]] = index_cast %[[ARG0]]
+  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<2x3x2xi32>
+  // CHECK:   linalg.yield %[[EXTRACT]]
+  %0 = "tosa.gather"(%arg0, %arg1)  : (tensor<2x3x2xi32>, tensor<2x3xi32>)  -> (tensor<2x3x2xi32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @table8
 func @table8(%arg0: tensor<6xi8>, %arg1: tensor<513xi8>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [6]