Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
===================================================================
--- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1120,14 +1120,16 @@
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    Tosa_Tensor:$input2,
-    Tosa_Tensor:$input3
+    I1Tensor:$pred,
+    Tosa_Tensor:$on_true,
+    Tosa_Tensor:$on_false
   );
 
   let results = (outs
     Tosa_Tensor:$output
   );
+  let hasCanonicalizeMethod = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
===================================================================
--- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -192,6 +192,19 @@
   results.add(context);
 }
 
+LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
+  // select(logical_not(p), a, b) -> select(p, b, a): drop the negation by
+  // swapping the two value operands in place.
+  auto notOp = op.pred().getDefiningOp<tosa::LogicalNotOp>();
+  if (!notOp)
+    return failure();
+  rewriter.updateRootInPlace(op, [&]() {
+    op.getOperation()->setOperands(
+        {notOp.input1(), op.on_false(), op.on_true()});
+  });
+  return success();
+}
+
 struct ConstantTransposeOptimization : public OpRewritePattern {
   using OpRewritePattern::OpRewritePattern;
 
@@ -585,12 +598,15 @@
     return {};                                                                 \
   }
 
-ReduceFolder(ReduceAllOp) ReduceFolder(ReduceAnyOp) ReduceFolder(ReduceMaxOp)
-    ReduceFolder(ReduceMinOp) ReduceFolder(ReduceProdOp)
-        ReduceFolder(ReduceSumOp)
+ReduceFolder(ReduceAllOp);
+ReduceFolder(ReduceAnyOp);
+ReduceFolder(ReduceMaxOp);
+ReduceFolder(ReduceMinOp);
+ReduceFolder(ReduceProdOp);
+ReduceFolder(ReduceSumOp);
 #undef ReduceFolder
 
-    OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   auto inputTy = input1().getType().dyn_cast();
   auto outputTy = getType().dyn_cast();
 
@@ -623,6 +639,23 @@
   return {};
 }
 
+OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
+  // A folded value must have the result type, but the operands may differ
+  // from it (e.g. when they are broadcast), so check before forwarding.
+  if (on_true() == on_false() && on_true().getType() == getType())
+    return on_true();
+
+  auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!predicate || !predicate.isSplat())
+    return {};
+
+  bool predValue = predicate.getSplatValue<APInt>().getBoolValue();
+  Value selected = predValue ? on_true() : on_false();
+  if (selected.getType() != getType())
+    return {};
+  return selected;
+}
+
 OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
   bool allOnes = true;
   for (Attribute val : multiples().getValue()) {
@@ -1951,7 +1984,7 @@
               resultKnowledge[index],
               ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
         resultKnowledge[index] = meet;
-      };
+      }
     }
   }
 
Index: mlir/test/Dialect/Tosa/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Tosa/canonicalize.mlir
+++ mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -252,6 +252,48 @@
 
 // -----
 
+// CHECK-LABEL: @select_same_value
+func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %0 = "tosa.select"(%arg0, %arg1, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK-NOT: tosa.select
+  // CHECK: return %arg1
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_true_value
+func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %c1 = "tosa.const"() {value = dense<1> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
+  %0 = "tosa.select"(%c1, %arg0, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK-NOT: tosa.select
+  // CHECK: return %arg0
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_false_value
+func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %c0 = "tosa.const"() {value = dense<0> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
+  %0 = "tosa.select"(%c0, %arg0, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK-NOT: tosa.select
+  // CHECK: return %arg1
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_not_pred
+func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %0 = "tosa.logical_not"(%arg0) : (tensor<2x3xi1>) -> tensor<2x3xi1>
+  %1 = "tosa.select"(%0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK: "tosa.select"(%arg0, %arg2, %arg1)
+  return %1 : tensor<2x3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_all_fold
 func @reduce_all_fold(%arg0: tensor) -> tensor {
   // CHECK: return %arg0