diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2486,6 +2486,52 @@
   }
 };
 
+/// Lowers tosa.scatter to a linalg.generic over dims (N, K, C, W): values_in
+/// and the result are [N, K, C], indices are [N, W] and input is [N, W, C].
+/// While reducing over W, result[n, k, c] takes input[n, w, c] when
+/// indices[n, w] == k and otherwise keeps its value (seeded from values_in).
+class ScatterConverter : public OpConversionPattern<tosa::ScatterOp> {
+public:
+  using OpConversionPattern<tosa::ScatterOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::ScatterOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    tosa::ScatterOpAdaptor adaptor(args); // Use the converted operands.
+    Value initial = adaptor.values_in();
+    Value indices = adaptor.indices();
+    Value updates = adaptor.input();
+    auto initialTy = initial.getType().cast<ShapedType>();
+
+    // d0 = N, d1 = K (destination row), d2 = C, d3 = W (reduced).
+    AffineExpr n, k, c, w;
+    bindDims(rewriter.getContext(), n, k, c, w);
+    auto affineMaps = AffineMap::inferFromExprList(
+        {ArrayRef<AffineExpr>({n, w}), ArrayRef<AffineExpr>({n, w, c}),
+         ArrayRef<AffineExpr>({n, k, c})});
+    SmallVector<StringRef> iteratorTypes = {
+        getParallelIteratorTypeName(), getParallelIteratorTypeName(),
+        getParallelIteratorTypeName(), getReductionIteratorTypeName()};
+
+    // The body is built only through the region builder `b`.
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        op.getLoc(), ArrayRef<Type>({initialTy}), ValueRange{indices, updates},
+        ValueRange{initial}, affineMaps, iteratorTypes,
+        [](OpBuilder &b, Location nestedLoc, ValueRange blockArgs) {
+          // blockArgs: indices[n, w], input[n, w, c], current result value.
+          Value row = b.create<linalg::IndexOp>(nestedLoc, 1);
+          Value target = b.create<IndexCastOp>(nestedLoc, b.getIndexType(),
+                                               blockArgs[0]);
+          Value hit =
+              b.create<CmpIOp>(nestedLoc, CmpIPredicate::eq, row, target);
+          Value next =
+              b.create<SelectOp>(nestedLoc, hit, blockArgs[1], blockArgs[2]);
+          b.create<linalg::YieldOp>(nestedLoc, next);
+        });
+    rewriter.replaceOp(op, genericOp.getResult(0));
+    return success();
+  }
+};
+
 // Lowerings the TableOp to a series of gathers and numerica operations. This
 // includes interpolation between the high/low values. For the I8 varient, this
 // simplifies to a single gather operation.
@@ -2934,6 +2980,7 @@
       RescaleConverter,
       ResizeConverter,
       ReverseConverter,
+      ScatterConverter,
       TableConverter,
       TileConverter,
       TransposeConverter,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1165,6 +1165,22 @@
 
 // -----
 
+// CHECK-LABEL: @scatter_float
+func @scatter_float(%arg0: tensor<2x8x16xf32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3x16xf32>) -> (tensor<2x8x16xf32>) {
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg1, %arg2 : tensor<2x3xi32>, tensor<2x3x16xf32>) outs(%arg0 : tensor<2x8x16xf32>)
+  // CHECK: ^bb0(%[[IDX:[a-zA-Z0-9_]+]]: i32, %[[UPDATE:[a-zA-Z0-9_]+]]: f32, %[[CURRENT:[a-zA-Z0-9_]+]]: f32)
+  // CHECK: %[[VAL1:.+]] = linalg.index 1
+  // CHECK: %[[VAL2:.+]] = index_cast %[[IDX]]
+  // CHECK: %[[VAL3:.+]] = cmpi eq, %[[VAL1]], %[[VAL2]]
+  // CHECK: %[[VAL4:.+]] = select %[[VAL3]], %[[UPDATE]], %[[CURRENT]]
+  // CHECK: linalg.yield %[[VAL4]]
+  // CHECK: return %[[GENERIC]]
+  %result = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<2x8x16xf32>, tensor<2x3xi32>, tensor<2x3x16xf32>) -> tensor<2x8x16xf32>
+  return %result : tensor<2x8x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @table8
 func @table8(%arg0: tensor<6xi8>, %arg1: tensor<512xi8>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [6]