diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -233,6 +233,91 @@
     return success();
   }
 };
+
+/// Aligns the ranks of the tosa.select operands with its output by reshaping
+/// lower-rank operands; the predicate is paired with each of on_true/on_false.
+template <>
+struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
+  using OpRewritePattern<tosa::SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
+                                PatternRewriter &rewriter) const override {
+
+    Value input1 = tosaOp.getOnTrue();
+    Value input2 = tosaOp.getOnFalse();
+    Value pred = tosaOp.getPred();
+    Value output = tosaOp.getResult();
+
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+    if (!outputType)
+      return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
+
+    // The pass repeatedly applies patterns until it reaches convergence, so
+    // broadcasting the whole operation may take several rounds. For
+    // example, let %pred, %input1, and %input2 be tensors a, b, and c, where
+    // rank(a) < rank(b) < rank(c).
+    //
+    // "tosa.select"(%pred, %input1, %input2) :
+    //   (rank<a>, rank<b>, rank<c>) -> rank<c>
+    //
+    // a) The first round of the pass:
+    //    broadcast(%pred, %input1) fails as its result (b) doesn't match the
+    //    output rank (c). broadcast(%pred, %input2) succeeds, giving
+    //
+    //    %pred` = "tosa.reshape"(%pred) -> rank<c>
+    //
+    // b) The second round of the pass:
+    //    broadcast(%pred`, %input1) succeeds, giving
+    //
+    //    %input1` = "tosa.reshape"(%input1) -> rank<c>
+    //
+    // c) Nothing changes in further rounds because the code cannot be
+    //    rewritten any further as it's already correct. Finally we get
+    //
+    //    %pred` = "tosa.reshape"(%pred) -> rank<c>
+    //    %input1` = "tosa.reshape"(%input1) -> rank<c>
+    //    "tosa.select"(%pred`, %input1`, %input2) :
+    //      (rank<c>, rank<c>, rank<c>) -> rank<c>
+
+    Value newInput1, newInput2, newPred1, newPred2;
+    bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          pred, input1, newPred1, newInput1)
+                         .succeeded();
+
+    bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          pred, input2, newPred2, newInput2)
+                         .succeeded();
+
+    if (!reshaped1 && !reshaped2)
+      return rewriter.notifyMatchFailure(
+          tosaOp,
+          "operands are fail to broadcast, check if the ranks are valid");
+
+    // Sanity-check that both predicates agree in rank. This is expected to
+    // hold whenever both calls succeed, as reshapeLowerToHigher() checks that
+    // the reshaped rank matches the output rank. NOTE(review): any reshape ops
+    // it already created would stay in the IR if this check ever failed.
+    if (reshaped1 && reshaped2) {
+      if (newPred1.getType().cast<RankedTensorType>().getRank() !=
+          newPred2.getType().cast<RankedTensorType>().getRank()) {
+        return rewriter.notifyMatchFailure(
+            tosaOp, "when all operands have been reshaped at the same round, "
+                    "both of them have to align with the maximum rank");
+      }
+    }
+
+    Value outInput1 = reshaped1 ? newInput1 : input1;
+    Value outInput2 = reshaped2 ? newInput2 : input2;
+    // When both calls succeeded, either new predicate can be used since their
+    // ranks were verified to match above.
+    Value outPred = reshaped1 ? newPred1 : (reshaped2 ? newPred2 : pred);
+
+    rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, outPred,
+                                                outInput1, outInput2);
+
+    return success();
+  }
+};
 } // namespace
 
 namespace {
@@ -265,6 +350,7 @@
     patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
     patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
     patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
+    patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
     (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
   }
 };
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -195,3 +195,91 @@
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
   return %0 : tensor<17x16x15x14xi32>
 }
+
+// -----
+// CHECK-LABEL: broadcast_select_both_input
+func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x16x16xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x16x16xi1>, tensor<f32>, tensor<f32>) -> tensor<1x16x16xf32>
+  return %0 : tensor<1x16x16xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_one_input
+func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor<f32>) -> tensor<17x16x15x14xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAL_1:.*]] = "tosa.select"(%arg0, %arg1, %[[VAL_0]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor<f32>) -> tensor<17x16x15x14xf32>
+  return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_predicate
+func.func @test_broadcast_select_predicate(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAL_1:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_abc
+func.func @test_broadcast_select_abc(%arg0: tensor<i1>, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 32, 8]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_acb
+func.func @test_broadcast_select_acb(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 32, 8]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_bac
+func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 32, 8]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<f32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_bca
+func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<f32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 32, 8]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor<f32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_cab
+func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 32, 8]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<f32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_cba
+func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor<f32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 32, 8]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor<f32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}