diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -380,6 +380,15 @@
     return rewriter.getIntegerAttr(
         elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
 
+  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.getFloatAttr(
+        elementTy, APFloat::getLargest(
+                       elementTy.cast<FloatType>().getFloatSemantics(), true));
+
+  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.getIntegerAttr(
+        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
+
   return {};
 }
 
@@ -856,6 +865,131 @@
   }
 };
 
+// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
+// op, producing two output buffers.
+//
+// The first output buffer contains the index of the found maximum value. It is
+// initialized to 0 and is of the resulting integer type.
+//
+// The second output buffer contains the maximum value found. It is initialized
+// to the minimum representable value of the input element type. After being
+// populated by indexed_generic, this buffer is discarded as only the index is
+// requested.
+//
+// The indexed_generic op updates both the maximum value and index if the
+// current value exceeds the running max.
+class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
+public:
+  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
+                                PatternRewriter &rewriter) const final {
+    auto loc = argmaxOp.getLoc();
+    Value input = argmaxOp.input();
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
+    auto inElementTy = inputTy.getElementType();
+    auto outElementTy = resultTy.getElementType();
+    int axis = argmaxOp.axis();
+    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
+
+    if (!inputTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          argmaxOp,
+          "tosa.arg_max to linalg.* requires statically shaped input");
+
+    if (!outElementTy.isa<IntegerType>())
+      return rewriter.notifyMatchFailure(
+          argmaxOp,
+          "tosa.arg_max to linalg.* requires integer-like result type");
+
+    // First, fill the output buffer for the index.
+    auto initTensorIdx =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
+                                          resultTy.getShape(), outElementTy)
+            .result();
+    auto fillValueIdx = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(outElementTy, 0));
+    auto filledTensorIdx =
+        rewriter.create<linalg::FillOp>(loc, initTensorIdx, fillValueIdx)
+            .result();
+
+    // Second, fill the output buffer for the running max.
+    auto initTensorMax =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
+                                          resultTy.getShape(), inElementTy)
+            .result();
+    auto fillValueMaxAttr =
+        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
+
+    if (!fillValueMaxAttr)
+      return rewriter.notifyMatchFailure(
+          argmaxOp, "unsupported tosa.argmax element type");
+
+    auto fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
+    auto filledTensorMax =
+        rewriter.create<linalg::FillOp>(loc, initTensorMax, fillValueMax)
+            .result();
+
+    // We need to reduce along the arg-max axis, with parallel operations along
+    // the rest.
+    SmallVector<StringRef, 4> iteratorTypes;
+    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
+    iteratorTypes[axis] = getReductionIteratorTypeName();
+
+    SmallVector<AffineExpr, 4> srcExprs;
+    SmallVector<AffineExpr, 4> dstExprs;
+    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
+      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
+      if (axis != i)
+        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
+    }
+
+    bool didEncounterError = false;
+    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
+    auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
+        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
+        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange ivs,
+            ValueRange blockArgs) {
+          auto newValue = blockArgs[0];
+          auto oldIndex = blockArgs[1];
+          auto oldValue = blockArgs[2];
+
+          Value newIndex = rewriter.create<IndexCastOp>(
+              nestedLoc, oldIndex.getType(), ivs[axis]);
+
+          Value predicate;
+          if (inElementTy.isa<FloatType>()) {
+            predicate = rewriter.create<mlir::CmpFOp>(
+                nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
+          } else if (inElementTy.isa<IntegerType>()) {
+            predicate = rewriter.create<mlir::CmpIOp>(
+                nestedLoc, CmpIPredicate::sgt, newValue, oldValue);
+          } else {
+            didEncounterError = true;
+            return;
+          }
+
+          auto resultMax = rewriter.create<mlir::SelectOp>(
+              nestedLoc, predicate, newValue, oldValue);
+          auto resultIndex = rewriter.create<mlir::SelectOp>(
+              nestedLoc, predicate, newIndex, oldIndex);
+          nestedBuilder.create<linalg::YieldOp>(
+              nestedLoc, ValueRange({resultIndex, resultMax}));
+        });
+
+    if (didEncounterError)
+      return rewriter.notifyMatchFailure(
+          argmaxOp, "unsupported tosa.argmax element type");
+
+    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -879,6 +1013,6 @@
       IdentityNConverter<tosa::IdentityOp>,
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
-      ReduceConverter<tosa::ReduceProdOp>, ConcatOpConversion,
+      ReduceConverter<tosa::ReduceProdOp>, ArgMaxConverter, ConcatOpConversion,
       ReshapeOpConverter, TransposeConverter, RescaleOpConverter>(context);
 }
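For orientation between the two files in this patch, here is a hand-written sketch of roughly the IR the pattern above emits for the axis = 1 case exercised by the test further down. The SSA names are invented and the assembly is approximated from this vintage of the linalg and std dialects (fill and indexed_generic printing in particular), so treat it as a sketch; the FileCheck lines in the test diff are the authoritative record of the printed form.

// Sketch only: argmax over axis 1 of a 3x2 tensor, one index per row.
// Two accumulators: the running index (init 0) and running max (init INT_MIN).
%init_idx = linalg.init_tensor [3] : tensor<3xi32>
%zero = constant 0 : i32
%idx = linalg.fill(%init_idx, %zero) : tensor<3xi32>, i32 -> tensor<3xi32>
%init_max = linalg.init_tensor [3] : tensor<3xi32>
%low = constant -2147483648 : i32
%max = linalg.fill(%init_max, %low) : tensor<3xi32>, i32 -> tensor<3xi32>
// d1 is the reduced axis; both output maps drop it.
%res:2 = linalg.indexed_generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%arg0 : tensor<3x2xi32>)
    outs(%idx, %max : tensor<3xi32>, tensor<3xi32>) {
^bb0(%d0: index, %d1: index, %in: i32, %old_idx: i32, %old_max: i32):
  // Candidate position along the reduced axis, cast to the result type.
  %cand = index_cast %d1 : index to i32
  %gt = cmpi sgt, %in, %old_max : i32
  %new_max = select %gt, %in, %old_max : i32
  %new_idx = select %gt, %cand, %old_idx : i32
  linalg.yield %new_idx, %new_max : i32, i32
} -> (tensor<3xi32>, tensor<3xi32>)
// %res#0 is the tosa.argmax result; %res#1 (the max values) is discarded.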
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -524,3 +524,51 @@
   %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>)
   return %0 : tensor<1xi8>
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
+
+func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
+  // CHECK: [[IDX_INIT:%.+]] = linalg.init_tensor [2]
+  // CHECK: [[IDX_MIN:%.+]] = constant 0 : i32
+  // CHECK: [[IDX_FILL:%.+]] = linalg.fill([[IDX_INIT]], [[IDX_MIN]])
+  // CHECK: [[VAL_INIT:%.+]] = linalg.init_tensor [2]
+  // CHECK: [[VAL_MIN:%.+]] = constant -2147483648
+  // CHECK: [[VAL_FILL:%.+]] = linalg.fill([[VAL_INIT]], [[VAL_MIN]])
+  // CHECK: linalg.indexed_generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<2xi32>, tensor<2xi32>)
+  // CHECK: [[CAST:%.+]] = index_cast %arg2
+  // CHECK: [[CMP:%.+]] = cmpi sgt, %arg4, %arg6
+  // CHECK: [[SELECT_VAL:%.+]] = select [[CMP]], %arg4, %arg6
+  // CHECK: [[SELECT_IDX:%.+]] = select [[CMP]], [[CAST]], %arg5
+  // CHECK: linalg.yield [[SELECT_IDX]], [[SELECT_VAL]]
+  %0 = "tosa.argmax"(%arg0) { axis = 0 : i64} : (tensor<3x2xi32>) -> (tensor<2xi32>)
+
+  // CHECK: [[IDX_INIT:%.+]] = linalg.init_tensor [3]
+  // CHECK: [[IDX_MIN:%.+]] = constant 0 : i32
+  // CHECK: [[IDX_FILL:%.+]] = linalg.fill([[IDX_INIT]], [[IDX_MIN]])
+  // CHECK: [[VAL_INIT:%.+]] = linalg.init_tensor [3]
+  // CHECK: [[VAL_MIN:%.+]] = constant -2147483648
+  // CHECK: [[VAL_FILL:%.+]] = linalg.fill([[VAL_INIT]], [[VAL_MIN]])
+  // CHECK: linalg.indexed_generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>)
+  // CHECK: [[CAST:%.+]] = index_cast %arg3
+  // CHECK: [[CMP:%.+]] = cmpi sgt, %arg4, %arg6
+  // CHECK: [[SELECT_VAL:%.+]] = select [[CMP]], %arg4, %arg6
+  // CHECK: [[SELECT_IDX:%.+]] = select [[CMP]], [[CAST]], %arg5
+  // CHECK: linalg.yield [[SELECT_IDX]], [[SELECT_VAL]]
+  %1 = "tosa.argmax"(%arg0) { axis = 1 : i64} : (tensor<3x2xi32>) -> (tensor<3xi32>)
+
+  // CHECK: constant -3.40282347E+38 : f32
+  // CHECK: index_cast
+  // CHECK: cmpf ogt
+  // CHECK: select
+  // CHECK: select
+  // CHECK: linalg.yield
+  %2 = "tosa.argmax"(%arg1) { axis = 0 : i64} : (tensor<6xf32>) -> (tensor<i32>)
+
+  return
+}
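Not shown in this hunk is the RUN line at the top of tosa-to-linalg.mlir that drives these CHECK lines. Assuming it follows the usual FileCheck pattern for this conversion's tests (the exact pass flag here is an assumption, not part of this diff), it would look something like:

// RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s -o - | FileCheck %s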