[mlir][tosa] Add canonicalizer and folder for tosa.select

Rename the tosa.select operands to pred/on_true/on_false and add:
  * a folder: select(p, x, x) -> x, and select(c, x, y) -> x or y when c is
    a splat i1 constant;
  * a canonicalization pattern: select(logical_not(p), x, y) ->
    select(p, y, x). The operands are swapped in place, so the mutation is
    wrapped in PatternRewriter::updateRootInPlace to notify the driver.

Also limit `// clang-format off` to the ReduceFolder macro block and drop a
stray `;` after an `if` block.

Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
===================================================================
--- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1120,14 +1120,16 @@
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    Tosa_Tensor:$input2,
-    Tosa_Tensor:$input3
+    I1Tensor:$pred,
+    Tosa_Tensor:$on_true,
+    Tosa_Tensor:$on_false
   );
 
   let results = (outs
     Tosa_Tensor:$output
   );
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
===================================================================
--- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -192,6 +192,33 @@
   results.add<ReshapeReshapeOptimization>(context);
 }
 
+/// Canonicalizes select(logical_not(p), a, b) into select(p, b, a).
+struct LogicalNotPredSelectOptimization
+    : public OpRewritePattern<tosa::SelectOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SelectOp selectOp,
+                                PatternRewriter &rewriter) const override {
+    auto notOp = selectOp.pred().getDefiningOp<tosa::LogicalNotOp>();
+    if (!notOp) {
+      return failure();
+    }
+    std::array<Value, 3> newOperands = {notOp.input1(), selectOp.on_false(),
+                                        selectOp.on_true()};
+    // The select is updated in place rather than replaced, so the mutation
+    // must go through the rewriter to keep the pattern driver informed.
+    rewriter.updateRootInPlace(selectOp, [&]() {
+      selectOp.getOperation()->setOperands(newOperands);
+    });
+    return success();
+  }
+};
+
+void tosa::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                 MLIRContext *context) {
+  results.add<LogicalNotPredSelectOptimization>(context);
+}
+
 struct ConstantTransposeOptimization
     : public OpRewritePattern<tosa::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -575,22 +602,27 @@
   return valueAttr();
 }
 
-#define ReduceFolder(OP)                                                       \
-  OpFoldResult OP::fold(ArrayRef<Attribute> operands) {                        \
-    ShapedType inputTy = input().getType().cast<ShapedType>();                 \
-    if (!inputTy.hasRank())                                                    \
-      return {};                                                               \
-    if (inputTy.getDimSize(axis()) == 1)                                       \
-      return input();                                                          \
-    return {};                                                                 \
-  }
-
-ReduceFolder(ReduceAllOp) ReduceFolder(ReduceAnyOp) ReduceFolder(ReduceMaxOp)
-    ReduceFolder(ReduceMinOp) ReduceFolder(ReduceProdOp)
-        ReduceFolder(ReduceSumOp)
+// The ReduceFolder invocations below have no trailing semicolons, so
+// clang-format would otherwise reflow them into one continued statement.
+// clang-format off
+#define ReduceFolder(OP)                                                       \
+  OpFoldResult OP::fold(ArrayRef<Attribute> operands) {                        \
+    ShapedType inputTy = input().getType().cast<ShapedType>();                 \
+    if (!inputTy.hasRank()) return {};                                         \
+    if (inputTy.getDimSize(axis()) == 1) return input();                       \
+    return {};                                                                 \
+  }
+
+ReduceFolder(ReduceAllOp)
+ReduceFolder(ReduceAnyOp)
+ReduceFolder(ReduceMaxOp)
+ReduceFolder(ReduceMinOp)
+ReduceFolder(ReduceProdOp)
+ReduceFolder(ReduceSumOp)
 #undef ReduceFolder
+// clang-format on
 
-    OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   auto inputTy = input1().getType().dyn_cast<RankedTensorType>();
   auto outputTy = getType().dyn_cast<RankedTensorType>();
 
@@ -623,6 +655,31 @@
   return {};
 }
 
+/// Folds select(p, x, x) to x, and a select whose predicate is a splat
+/// constant to the operand that the splat value chooses.
+OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
+  if (on_true() == on_false()) {
+    return on_true();
+  }
+
+  auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!predicate) {
+    return {};
+  }
+
+  auto predicateTy = predicate.getType().cast<ShapedType>();
+  if (!predicateTy.getElementType().isInteger(1)) {
+    return {};
+  }
+
+  if (predicate.isSplat()) {
+    return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
+                                                           : on_false();
+  }
+
+  return {};
+}
+
 OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
   bool allOnes = true;
   for (Attribute val : multiples().getValue()) {
@@ -1951,7 +2008,7 @@
               resultKnowledge[index],
               ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
         resultKnowledge[index] = meet;
-      };
+      }
     }
   }
 
Index: mlir/test/Dialect/Tosa/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Tosa/canonicalize.mlir
+++ mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -252,6 +252,48 @@
 
 // -----
 
+// CHECK-LABEL: @select_same_value
+func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %0 = "tosa.select"(%arg0, %arg1, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK: return %arg1
+  // CHECK-NOT: tosa.select
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_true_value
+func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %c1 = "tosa.const"() {value = dense<1> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
+  %0 = "tosa.select"(%c1, %arg0, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK: return %arg0
+  // CHECK-NOT: tosa.select
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_false_value
+func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %c0 = "tosa.const"() {value = dense<0> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
+  %0 = "tosa.select"(%c0, %arg0, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK: return %arg1
+  // CHECK-NOT: tosa.select
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_not_pred
+func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %0 = "tosa.logical_not"(%arg0) : (tensor<2x3xi1>) -> tensor<2x3xi1>
+  %1 = "tosa.select"(%0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK: "tosa.select"(%arg0, %arg2, %arg1)
+  return %1 : tensor<2x3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_all_fold
 func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0