diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -86,17 +86,18 @@
   auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
   auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();

-  if (!input1Ty || !input2Ty)
-    return failure();
+  if (!input1Ty || !input2Ty) {
+    return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
+  }

   int64_t input1Rank = input1Ty.getRank();
   int64_t input2Rank = input2Ty.getRank();

-  Value higherTensorValue, lowerTensorValue;
-  // Cannot rewrite as its already correct.
   if (input1Rank == input2Rank)
-    return failure();
+    return rewriter.notifyMatchFailure(loc,
+                                       "cannot rewrite as it's already correct");

+  Value higherTensorValue, lowerTensorValue;
   if (input1Rank > input2Rank) {
     higherTensorValue = input1;
     lowerTensorValue = input2;
@@ -107,7 +108,6 @@

   ArrayRef<int64_t> higherRankShape =
       higherTensorValue.getType().cast<RankedTensorType>().getShape();
-  (void)higherRankShape;
   ArrayRef<int64_t> lowerRankShape =
       lowerTensorValue.getType().cast<RankedTensorType>().getShape();

@@ -115,7 +115,7 @@

   if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
           .failed())
-    return failure();
+    return rewriter.notifyMatchFailure(loc, "failed to compute a reshape type");

   auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
   auto reshapeOutputType = RankedTensorType::get(
@@ -125,7 +125,8 @@
   if (outputType) {
     if (outputType.getShape().size() != reshapeOutputShape.size() ||
         outputType.getShape().size() != higherRankShape.size())
-      return failure();
+      return rewriter.notifyMatchFailure(
+          loc, "the reshaped type doesn't agree with the ranked output type");
   }

   auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
@@ -144,7 +145,8 @@
 }

 namespace {
-template <typename OpTy> struct ConvertTosaOp : public OpRewritePattern<OpTy> {
+template <typename OpTy>
+struct ConvertTosaOp : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;

   LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
@@ -232,6 +234,60 @@
     return success();
   }
 };
+
+template <>
+struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
+  using OpRewritePattern<tosa::SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
+                                PatternRewriter &rewriter) const override {
+
+    Value input1 = tosaOp.getPred();
+    Value input2 = tosaOp.getOnTrue();
+    Value input3 = tosaOp.getOnFalse();
+    Value output = tosaOp.getResult();
+
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+    if (!outputType)
+      return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
+
+    Value result1 = input1;
+    Value result2 = input2;
+    Value result3 = input3;
+
+    // Apply broadcasting to each pair of inputs separately, and chain the
+    // rewritten operands so that all three are broadcast together.
+    bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          input1, input2, result1, result2)
+                         .succeeded();
+
+    bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          result1, input3, result1, result3)
+                         .succeeded();
+
+    bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          result2, input3, result2, result3)
+                         .succeeded();
+
+    if (!reshaped1 && !reshaped2 && !reshaped3)
+      return rewriter.notifyMatchFailure(
+          tosaOp,
+          "cannot rewrite as the ranks of all operands are already aligned");
+
+    int32_t result1Rank = result1.getType().cast<RankedTensorType>().getRank();
+    int32_t result2Rank = result2.getType().cast<RankedTensorType>().getRank();
+    int32_t result3Rank = result3.getType().cast<RankedTensorType>().getRank();
+
+    if ((result1Rank != result2Rank) || (result2Rank != result3Rank))
+      return rewriter.notifyMatchFailure(
+          tosaOp, "not all ranks are aligned with each other");
+
+    rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, result1,
+                                                result2, result3);
+
+    return success();
+  }
+};
 } // namespace

 namespace {
@@ -263,6 +319,7 @@
   patterns.add<ConvertTosaOp<...>>(ctx);
   patterns.add<ConvertTosaOp<...>>(ctx);
   patterns.add<ConvertTosaOp<...>>(ctx);
+  patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
   patterns.add<ConvertTosaOp<...>>(ctx);
   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
 }
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -195,3 +195,91 @@
   %0 = "tosa.add"(%arg0, %arg1) : (tensor, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
   return %0 : tensor<17x16x15x14xi32>
 }
+
+// -----
+// CHECK-LABEL: broadcast_select_both_input
+func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor, %arg2: tensor) -> tensor<1x16x16xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x16x16xi1>, tensor, tensor) -> tensor<1x16x16xf32>
+  return %0 : tensor<1x16x16xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_one_input
+func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor) -> tensor<17x16x15x14xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg2) {new_shape = array}
+  // CHECK: %[[VAL_1:.*]] = "tosa.select"(%arg0, %arg1, %[[VAL_0]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor) -> tensor<17x16x15x14xf32>
+  return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_predicate
+func.func @test_broadcast_select_predicate(%arg0: tensor, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array}
+  // CHECK: %[[VAL_1:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_abc
+func.func @test_broadcast_select_abc(%arg0: tensor, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = array}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
"tosa.select"(%arg0, %arg1, %arg2) : (tensor, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- +// CHECK-LABEL: broadcast_select_acb +func.func @test_broadcast_select_acb(%arg0: tensor, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> { + // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} + // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array} + // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]]) + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- +// CHECK-LABEL: broadcast_select_bac +func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} + // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = array} + // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2) + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- +// CHECK-LABEL: broadcast_select_bca +func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor) -> tensor<1x32x32x8xf32> { + // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} + // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array} + // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]]) + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- +// CHECK-LABEL: broadcast_select_cab +func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> { + // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array} + // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array} + // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]]) + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor, tensor<32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- +// CHECK-LABEL: broadcast_select_cba +func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor) -> tensor<1x32x32x8xf32> { + // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array} + // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array} + // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]]) + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +}