diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -269,15 +269,15 @@
   SmallVector<Type> opResultTypes;
   SmallVector<Value> initTensors;
   for (auto result : results) {
-    auto resultType = result.getType().template cast<ShapedType>();
-    if (!resultType.hasStaticShape())
+    auto resultTy = result.getType().template cast<ShapedType>();
+    if (!resultTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
           operation,
           "tosa to linalg conversion expects statically shaped tensors");
 
     initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
-        loc, ArrayRef<Value>({}), resultType.getShape(),
-        resultType.getElementType()));
+        loc, ArrayRef<Value>({}), resultTy.getShape(),
+        resultTy.getElementType()));
     opResultTypes.push_back(result.getType());
   }
 
@@ -330,6 +330,149 @@
   return success();
 }
 
+static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
+                                               PatternRewriter &rewriter) {
+  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.getFloatAttr(elementTy, 0.0);
+
+  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.getIntegerAttr(elementTy, 0);
+
+  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.getFloatAttr(elementTy, 1.0);
+
+  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.getIntegerAttr(elementTy, 1);
+
+  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.getFloatAttr(
+        elementTy, APFloat::getLargest(
+                       elementTy.cast<FloatType>().getFloatSemantics(), false));
+
+  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.getIntegerAttr(
+        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
+
+  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.getFloatAttr(
+        elementTy, APFloat::getLargest(
+                       elementTy.cast<FloatType>().getFloatSemantics(), true));
+
+  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.getIntegerAttr(
+        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
+
+  return {};
+}
+
+static Value createLinalgBodyCalculationForReduceOp(Operation *op,
+                                                    ValueRange args,
+                                                    Type elementTy,
+                                                    PatternRewriter &rewriter) {
+  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
+    return rewriter.create<AddFOp>(op->getLoc(), args).getResult();
+  }
+
+  if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
+    return rewriter.create<AddIOp>(op->getLoc(), args).getResult();
+  }
+
+  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
+    return rewriter.create<MulFOp>(op->getLoc(), args).getResult();
+  }
+
+  if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
+    return rewriter.create<MulIOp>(op->getLoc(), args).getResult();
+  }
+
+  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
+    auto predicate = rewriter.create<CmpFOp>(
+        op->getLoc(), CmpFPredicate::OLT, args[0], args[1]);
+    return rewriter.create<SelectOp>(op->getLoc(), predicate, args[0],
+                                     args[1]);
+  }
+
+  if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
+    auto predicate = rewriter.create<CmpIOp>(
+        op->getLoc(), CmpIPredicate::slt, args[0], args[1]);
+    return rewriter.create<SelectOp>(op->getLoc(), predicate, args[0],
+                                     args[1]);
+  }
+
+  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
+    auto predicate = rewriter.create<CmpFOp>(
+        op->getLoc(), CmpFPredicate::OGT, args[0], args[1]);
+    return rewriter.create<SelectOp>(op->getLoc(), predicate, args[0],
+                                     args[1]);
+  }
+
+  if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
+    auto predicate = rewriter.create<CmpIOp>(
+        op->getLoc(), CmpIPredicate::sgt, args[0], args[1]);
+    return rewriter.create<SelectOp>(op->getLoc(), predicate, args[0],
+                                     args[1]);
+  }
+
+  return {};
+}
+
+static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
+                                                 PatternRewriter &rewriter) {
+  auto loc = op->getLoc();
+  auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
+  auto resultTy =
+      op->getResult(0).getType().template cast<ShapedType>();
+  auto elementTy = resultTy.getElementType();
+  Value input = op->getOperand(0);
+
+  // First fill the output buffer with the init value.
+  auto initTensor =
+      rewriter
+          .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
+                                        resultTy.getShape(),
+                                        resultTy.getElementType())
+          .result();
+
+  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
+  if (!fillValueAttr)
+    return rewriter.notifyMatchFailure(
+        op, "No initial value found for reduction operation");
+
+  auto fillValue = rewriter.create<ConstantOp>(loc, fillValueAttr);
+  auto filledTensor =
+      rewriter.create<linalg::FillOp>(loc, initTensor, fillValue).result();
+
+  SmallVector<AffineExpr, 2> srcExprs;
+  SmallVector<AffineExpr, 2> dstExprs;
+  SmallVector<StringRef, 4> iteratorTypes;
+  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
+    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
+
+    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
+                                      : getParallelIteratorTypeName());
+    if (axis != i)
+      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
+  }
+
+  bool didEncounterError = false;
+  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
+  auto linalgOp = rewriter.create<linalg::GenericOp>(
+      loc, resultTy, input, filledTensor, maps, iteratorTypes,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
+        auto result = createLinalgBodyCalculationForReduceOp(
+            op, blockArgs, elementTy, rewriter);
+        if (result)
+          didEncounterError = true;
+
+        nestedBuilder.create<linalg::YieldOp>(loc, result);
+      });
+
+  if (!didEncounterError)
+    return failure();
+
+  rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
+  return success();
+}
+
 namespace {
 
 template <typename SrcOp>
@@ -500,6 +643,18 @@
   }
 };
 
+template <typename SrcOp>
+class ReduceConverter : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SrcOp reduceOp,
+                                PatternRewriter &rewriter) const final {
+    return reduceMatchAndRewriteHelper(reduceOp.getOperation(),
+                                       reduceOp.axis(), rewriter);
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -521,6 +676,8 @@
       PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
       PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
       IdentityNConverter<tosa::IdentityOp>,
-      IdentityNConverter<tosa::IdentityNOp>,
-      ReshapeOpConverter, TransposeConverter>(context);
+      IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
+      ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
+      ReduceConverter<tosa::ReduceProdOp>, ReshapeOpConverter,
+      TransposeConverter>(context);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -335,3 +335,101 @@
   %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
   return
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL: @reduce_float
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
+func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
+  // CHECK: [[CST0:%.+]] = constant 0.0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<4xf32>)
+  // CHECK: ^bb0(%arg1: f32, %arg2: f32)
+  // CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32
+  // CHECK: linalg.yield [[RES]] : f32
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
+
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
+  // CHECK: [[CST0:%.+]] = constant 0.0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5xf32>)
+  // CHECK: ^bb0(%arg1: f32, %arg2: f32)
+  // CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32
+  // CHECK: linalg.yield [[RES]] : f32
+  %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5xf32>
+
+  // CHECK: constant 1.0
+  // CHECK: linalg.fill
+  // CHECK: linalg.generic
+  // CHECK: mulf
+  %2 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
+
+  // CHECK: constant 3.40282347E+38 : f32
+  // CHECK: linalg.fill
+  // CHECK: linalg.generic
+  // CHECK: cmpf olt
+  // CHECK: select
+  %3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
+
+  // CHECK: constant -3.40282347E+38 : f32
+  // CHECK: linalg.fill
+  // CHECK: linalg.generic
+  // CHECK: cmpf ogt
+  // CHECK: select
+  %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL: @reduce_int
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi32>
+func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
+  // CHECK: [[CST0:%.+]] = constant 0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<4xi32>)
+  // CHECK: ^bb0(%arg1: i32, %arg2: i32)
+  // CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32
+  // CHECK: linalg.yield [[RES]] : i32
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
+
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
+  // CHECK: [[CST0:%.+]] = constant 0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5xi32>)
+  // CHECK: ^bb0(%arg1: i32, %arg2: i32)
+  // CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32
+  // CHECK: linalg.yield [[RES]] : i32
+  %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5xi32>
+
+  // CHECK: constant 1
+  // CHECK: linalg.fill
+  // CHECK: linalg.generic
+  // CHECK: muli
+  %2 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
+
+  // CHECK: constant 2147483647 : i32
+  // CHECK: linalg.fill
+  // CHECK: linalg.generic
+  // CHECK: cmpi slt
+  // CHECK: select
+  %3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
+
+  // CHECK: constant -2147483648 : i32
+  // CHECK: linalg.fill
+  // CHECK: linalg.generic
+  // CHECK: cmpi sgt
+  // CHECK: select
+  %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
+  return
+}
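
For reference, below is a sketch of the complete IR this lowering is expected to produce for a single tosa.reduce_sum along axis 0, assembled from the CHECK lines above. It is not part of the patch; the function name and SSA value names are illustrative only. The reduction is expressed as a linalg.generic whose output indexing map drops the reduced dimension, accumulating into a tensor pre-filled with the additive identity.

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1)>

func @reduce_sum_example(%arg0: tensor<5x4xf32>) -> tensor<4xf32> {
  // Create the accumulator and fill it with the additive identity.
  %init = linalg.init_tensor [4] : tensor<4xf32>
  %cst = constant 0.000000e+00 : f32
  %fill = linalg.fill(%init, %cst) : tensor<4xf32>, f32 -> tensor<4xf32>
  // d0 is the reduction dimension, so only d1 appears in the output map.
  %res = linalg.generic {indexing_maps = [#map0, #map1],
                         iterator_types = ["reduction", "parallel"]}
      ins(%arg0 : tensor<5x4xf32>) outs(%fill : tensor<4xf32>) {
  ^bb0(%in: f32, %acc: f32):  // %in is the input element, %acc the running sum.
    %sum = addf %in, %acc : f32
    linalg.yield %sum : f32
  } -> tensor<4xf32>
  return %res : tensor<4xf32>
}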