diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1376,6 +1376,129 @@ } }; +// Lowers the TableOp to a series of gathers and numerical operations. For the +// I16 variant this includes interpolation between the base/next table values; +// for the I8 variant it simplifies to a single gather operation. +class TableConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::TableOp op, + PatternRewriter &rewriter) const final { + auto loc = op.getLoc(); + Value input = op.input(); + Value table = op.table(); + auto inputTy = input.getType().template cast(); + auto tableTy = table.getType().template cast(); + auto resultTy = op.getType().template cast(); + + auto inputElementTy = inputTy.getElementType(); + auto tableElementTy = tableTy.getElementType(); + auto resultElementTy = resultTy.getElementType(); + + if (!inputTy.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "No initial value found for table operation"); + + auto initTensor = + rewriter + .create(loc, ArrayRef({}), + resultTy.getShape(), resultElementTy) + .result(); + + SmallVector affineMaps = { + rewriter.getMultiDimIdentityMap(resultTy.getRank()), + rewriter.getMultiDimIdentityMap(resultTy.getRank())}; + + bool bodySucceeded = true; + + rewriter.replaceOpWithNewOp( + op, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps, + getNParallelLoopsAttrs(resultTy.getRank()), + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && + resultElementTy.isInteger(8)) { + Value index = nestedBuilder.create( + nestedLoc, rewriter.getIndexType(), args[0]); + Value extract = nestedBuilder.create( + nestedLoc, table, ValueRange{index}); + nestedBuilder.create(nestedLoc, extract); + 
return; + } + + if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) && + resultElementTy.isInteger(32)) { + Value extend = rewriter.create( + nestedLoc, rewriter.getI32Type(), args[0]); + + auto offset = nestedBuilder.create( + nestedLoc, 32768, nestedBuilder.getI32Type()); + + auto seven = nestedBuilder.create( + nestedLoc, 7, nestedBuilder.getI32Type()); + + auto one = nestedBuilder.create( + nestedLoc, 1, nestedBuilder.getI32Type()); + auto b1111111 = nestedBuilder.create( + nestedLoc, 127, nestedBuilder.getI32Type()); + + // Compute the index and fractional part from the input value: + // value = value + 32768 (shifts the i16 range to [0, 65535]) + // index = value >> 7; + // fraction = 0b01111111 & value + auto extendAdd = + nestedBuilder.create(nestedLoc, extend, offset); + Value index = nestedBuilder.create( + nestedLoc, extendAdd, seven); + Value fraction = nestedBuilder.create( + nestedLoc, extendAdd, b1111111); + + // Extract the base and next values from the table. + // base = (int32_t) table[index]; + // next = (int32_t) table[index + 1]; + Value indexPlusOne = + nestedBuilder.create(nestedLoc, index, one); + + index = nestedBuilder.create( + nestedLoc, nestedBuilder.getIndexType(), index); + indexPlusOne = nestedBuilder.create( + nestedLoc, nestedBuilder.getIndexType(), indexPlusOne); + + Value base = nestedBuilder.create( + nestedLoc, table, ValueRange{index}); + Value next = nestedBuilder.create( + nestedLoc, table, ValueRange{indexPlusOne}); + + base = nestedBuilder.create( + nestedLoc, nestedBuilder.getI32Type(), base); + next = nestedBuilder.create( + nestedLoc, nestedBuilder.getI32Type(), next); + + // Use the fractional part to interpolate between the table values: + // result = (base << 7) + (next - base) * fraction + Value baseScaled = + nestedBuilder.create(nestedLoc, base, seven); + Value diff = nestedBuilder.create(nestedLoc, next, base); + Value diffScaled = + nestedBuilder.create(nestedLoc, diff, fraction); + Value result = + nestedBuilder.create(nestedLoc, 
baseScaled, diffScaled); + + nestedBuilder.create(nestedLoc, result); + return; + } + + bodySucceeded = false; + }); + + if (!bodySucceeded) + return rewriter.notifyMatchFailure( + op, "unable to create body for tosa.table op."); + + return success(); + } +}; + } // namespace void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns( @@ -1404,8 +1527,8 @@ IdentityNConverter, IdentityNConverter, ReduceConverter, ReduceConverter, ReduceConverter, - ReduceConverter, ArgMaxConverter, ConcatConverter, PadConverter, - ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter, - TransposeConverter, MatMulConverter, FullyConnectedConverter>( - patterns->getContext()); + ReduceConverter, ArgMaxConverter, ConcatConverter, + PadConverter, ReshapeConverter, RescaleConverter, ReverseConverter, + TableConverter, TileConverter, TransposeConverter, MatMulConverter, + FullyConnectedConverter>(patterns->getContext()); } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -793,3 +793,47 @@ return } + +// ----- + +// CHECK-LABEL: @table8 +func @table8(%arg0: tensor<6xi8>, %arg1: tensor<513xi8>) -> () { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [6] + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi8>) outs([[INIT]] : tensor<6xi8>) + // CHECK: ^bb0([[ARG_IN:%.+]]: i8, [[ARG_INIT:%.+]]: i8) + // CHECK: [[CAST:%.+]] = index_cast [[ARG_IN]] + // CHECK: [[EXTRACT:%.+]] = tensor.extract %arg1{{\[}}[[CAST]]] + // CHECK: linalg.yield [[EXTRACT]] + %0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi8>, tensor<513xi8>) -> (tensor<6xi8>) + return +} + +// CHECK-LABEL: @table16 +func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [6] + // CHECK: 
[[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi16>) outs(%0 : tensor<6xi32>) + // CHECK: ^bb0(%arg2: i16, %arg3: i32) + // CHECK: [[EXT_IN:%.+]] = sexti %arg2 + // CHECK: [[C32768:%.+]] = constant 32768 + // CHECK: [[C7:%.+]] = constant 7 + // CHECK: [[C1:%.+]] = constant 1 + // CHECK: [[C127:%.+]] = constant 127 + // CHECK: [[INADD:%.+]] = addi [[EXT_IN]], [[C32768]] + // CHECK: [[IDX:%.+]] = shift_right_unsigned [[INADD]], [[C7]] + // CHECK: [[FRACTION:%.+]] = and [[INADD]], [[C127]] + // CHECK: [[IDXPLUS1:%.+]] = addi [[IDX]], [[C1]] + // CHECK: [[IDX_CAST:%.+]] = index_cast [[IDX]] + // CHECK: [[IDXPLUS1_CAST:%.+]] = index_cast [[IDXPLUS1]] + // CHECK: [[BASE:%.+]] = tensor.extract %arg1{{\[}}[[IDX_CAST]]] + // CHECK: [[NEXT:%.+]] = tensor.extract %arg1{{\[}}[[IDXPLUS1_CAST]]] + // CHECK: [[BASE_EXT:%.+]] = sexti [[BASE]] + // CHECK: [[NEXT_EXT:%.+]] = sexti [[NEXT]] + // CHECK: [[BASE_MUL:%.+]] = shift_left [[BASE_EXT]], [[C7]] + // CHECK: [[DIFF:%.+]] = subi [[NEXT_EXT]], [[BASE_EXT]] + // CHECK: [[DIFF_MUL:%.+]] = muli [[DIFF]], [[FRACTION]] + // CHECK: [[RESULT:%.+]] = addi [[BASE_MUL]], [[DIFF_MUL]] + // CHECK: linalg.yield [[RESULT]] + %0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi16>, tensor<513xi16>) -> (tensor<6xi32>) + return +} +