diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -233,6 +233,57 @@
     return success();
   }
 };
+
+// Specialization for tosa.select: unlike the binary ops handled by the
+// generic pattern above, the predicate operand must be rank-broadcast
+// together with each of the two data operands, so it needs its own pattern.
+template <>
+struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
+  using OpRewritePattern<tosa::SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
+                                PatternRewriter &rewriter) const override {
+
+    Value input1 = tosaOp.getOnTrue();
+    Value input2 = tosaOp.getOnFalse();
+    Value pred = tosaOp.getPred();
+    Value output = tosaOp.getResult();
+
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+    if (!outputType)
+      return failure();
+
+    Value newInput1, newInput2, newPred1, newPred2;
+    bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          pred, input1, newPred1, newInput1)
+                         .succeeded();
+
+    bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          pred, input2, newPred2, newInput2)
+                         .succeeded();
+
+    // If neither operand needed reshaping, there is nothing to broadcast.
+    if (!reshaped1 && !reshaped2)
+      return failure();
+
+    // Verify the ranks of the reshaped operands agree with each other.
+    if (reshaped1 && reshaped2) {
+      if (newPred1.getType().dyn_cast<RankedTensorType>().getRank() !=
+          newPred2.getType().dyn_cast<RankedTensorType>().getRank()) {
+        return failure();
+      }
+    }
+
+    Value outInput1 = reshaped1 ? newInput1 : input1;
+    Value outInput2 = reshaped2 ? newInput2 : input2;
+    Value outPred = reshaped1 ? newPred1 : (reshaped2 ? newPred2 : pred);
+
+    rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, outPred,
+                                                outInput1, outInput2);
+
+    return success();
+  }
+};
 } // namespace
 
 namespace {
@@ -265,6 +316,7 @@
     patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
     patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
     patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
+    patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
     (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
   }
 };
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -195,3 +195,32 @@
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
   return %0 : tensor<17x16x15x14xi32>
 }
+
+// -----
+// CHECK-LABEL: broadcast_select1
+func.func @test_broadcast_select1(%arg0: tensor<1x16x16xi1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x16x16xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x16x16xi1>, tensor<f32>, tensor<f32>) -> tensor<1x16x16xf32>
+  return %0 : tensor<1x16x16xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select2
+func.func @test_broadcast_select2(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor<f32>) -> tensor<17x16x15x14xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAL_1:.*]] = "tosa.select"(%arg0, %arg1, %[[VAL_0]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor<f32>) -> tensor<17x16x15x14xf32>
+  return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select3
+func.func @test_broadcast_select3(%arg0: tensor<i1>, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 32, 8]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}