Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
===================================================================
--- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1120,14 +1120,16 @@
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    Tosa_Tensor:$input2,
-    Tosa_Tensor:$input3
+    I1Tensor:$pred,
+    Tosa_Tensor:$on_true,
+    Tosa_Tensor:$on_false
   );
 
   let results = (outs
     Tosa_Tensor:$output
   );
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
===================================================================
--- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -192,6 +192,33 @@
   results.add(context);
 }
 
+// Rewrites select(logical_not(p), a, b) into select(p, b, a) so that the
+// select consumes the un-negated predicate directly.
+struct LogicalNotPredSelectOptimization
+    : public OpRewritePattern<tosa::SelectOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SelectOp selectOp,
+                                PatternRewriter &rewriter) const override {
+    // Only applies when the predicate is produced by a tosa.logical_not.
+    auto notOp = selectOp.pred().getDefiningOp<tosa::LogicalNotOp>();
+    if (!notOp)
+      return failure();
+    // The operands are mutated in place, so the change has to be reported
+    // through the rewriter instead of being applied behind its back.
+    rewriter.updateRootInPlace(selectOp, [&]() {
+      selectOp.getOperation()->setOperands(
+          {notOp.input1(), selectOp.on_false(), selectOp.on_true()});
+    });
+    return success();
+  }
+};
+
+void tosa::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                 MLIRContext *context) {
+  results.add<LogicalNotPredSelectOptimization>(context);
+}
+
 struct ConstantTransposeOptimization
     : public OpRewritePattern {
   using OpRewritePattern::OpRewritePattern;
@@ -585,12 +612,15 @@
     return {}; \
   }
 
-ReduceFolder(ReduceAllOp) ReduceFolder(ReduceAnyOp) ReduceFolder(ReduceMaxOp)
-    ReduceFolder(ReduceMinOp) ReduceFolder(ReduceProdOp)
-        ReduceFolder(ReduceSumOp)
+ReduceFolder(ReduceAllOp);
+ReduceFolder(ReduceAnyOp);
+ReduceFolder(ReduceMaxOp);
+ReduceFolder(ReduceMinOp);
+ReduceFolder(ReduceProdOp);
+ReduceFolder(ReduceSumOp);
 #undef ReduceFolder
 
-    OpFoldResult ReshapeOp::fold(ArrayRef operands) {
+OpFoldResult ReshapeOp::fold(ArrayRef operands) {
   auto inputTy = input1().getType().dyn_cast();
   auto outputTy = getType().dyn_cast();
 
@@ -623,6 +653,24 @@
   return {};
 }
 
+// Folds a select whose result does not depend on the predicate (both branches
+// are the same value) or whose predicate is a splat constant.
+OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
+  if (on_true() == on_false())
+    return on_true();
+
+  // Without a constant dense predicate there is nothing to fold.
+  auto predicate = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  if (!predicate)
+    return {};
+
+  // A non-splat predicate may pick elements from both branches.
+  if (!predicate.isSplat())
+    return {};
+  return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
+                                                         : on_false();
+}
+
 OpFoldResult TileOp::fold(ArrayRef operands) {
   bool allOnes = true;
   for (Attribute val : multiples().getValue()) {
@@ -1951,7 +1999,7 @@
               resultKnowledge[index],
               ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
         resultKnowledge[index] = meet;
-      };
+      }
     }
   }
 
Index: mlir/test/Dialect/Tosa/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Tosa/canonicalize.mlir
+++ mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -252,6 +252,48 @@
 
 // -----
 
+// CHECK-LABEL: @select_same_value
+func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %0 = "tosa.select"(%arg0, %arg1, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK-NOT: tosa.select
+  // CHECK: return %arg1
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_true_value
+func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %c1 = "tosa.const"() {value = dense<1> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
+  %0 = "tosa.select"(%c1, %arg0, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK-NOT: tosa.select
+  // CHECK: return %arg0
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_false_value
+func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %c0 = "tosa.const"() {value = dense<0> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
+  %0 = "tosa.select"(%c0, %arg0, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK-NOT: tosa.select
+  // CHECK: return %arg1
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_not_pred
+func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %0 = "tosa.logical_not"(%arg0) : (tensor<2x3xi1>) -> tensor<2x3xi1>
+  %1 = "tosa.select"(%0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK: "tosa.select"(%arg0, %arg2, %arg1)
+  return %1 : tensor<2x3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_all_fold
 func @reduce_all_fold(%arg0: tensor) -> tensor {
   // CHECK: return %arg0