diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -75,28 +75,28 @@
 }
 
 /// Common code to create the reshape op where necessary to make the rank of the
-/// operations equal. Returns the updated input1 and input2 for the original
-/// input. The caller is expected to use these to rewrite the original operator
-/// with the RESHAPE now in the graph.
+/// operations equal. input1 and input2 will be updated when the rank has
+/// changed. The caller is expected to use these to rewrite the original
+/// operator with the RESHAPE now in the graph.
 static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
                                           Location loc,
                                           RankedTensorType outputType,
-                                          Value input1, Value input2,
-                                          Value &outInput1, Value &outInput2) {
+                                          Value &input1, Value &input2) {
   auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
   auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
 
-  if (!input1Ty || !input2Ty)
-    return failure();
+  if (!input1Ty || !input2Ty) {
+    return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
+  }
 
   int64_t input1Rank = input1Ty.getRank();
   int64_t input2Rank = input2Ty.getRank();
 
-  Value higherTensorValue, lowerTensorValue;
-  // Cannot rewrite as its already correct.
   if (input1Rank == input2Rank)
-    return failure();
+    return rewriter.notifyMatchFailure(loc,
+                                       "cannot rewrite as it's already correct");
 
+  Value higherTensorValue, lowerTensorValue;
   if (input1Rank > input2Rank) {
     higherTensorValue = input1;
     lowerTensorValue = input2;
@@ -107,7 +107,6 @@
 
   ArrayRef<int64_t> higherRankShape =
       higherTensorValue.getType().cast<RankedTensorType>().getShape();
-  (void)higherRankShape;
   ArrayRef<int64_t> lowerRankShape =
       lowerTensorValue.getType().cast<RankedTensorType>().getShape();
 
@@ -115,7 +114,7 @@
   if (computeReshapeOutput(higherRankShape, lowerRankShape,
                            reshapeOutputShape)
           .failed())
-    return failure();
+    return rewriter.notifyMatchFailure(loc, "failed to compute a reshape type");
 
   auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
   auto reshapeOutputType = RankedTensorType::get(
@@ -125,7 +124,8 @@
   if (outputType) {
     if (outputType.getShape().size() != reshapeOutputShape.size() ||
         outputType.getShape().size() != higherRankShape.size())
-      return failure();
+      return rewriter.notifyMatchFailure(
+          loc, "the reshaped type doesn't agree with the ranked output type");
   }
 
   auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
@@ -133,18 +133,19 @@
       rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
 
   if (input1Rank > input2Rank) {
-    outInput1 = higherTensorValue;
-    outInput2 = reshapeLower.getResult();
+    input1 = higherTensorValue;
+    input2 = reshapeLower.getResult();
   } else {
-    outInput1 = reshapeLower.getResult();
-    outInput2 = higherTensorValue;
+    input1 = reshapeLower.getResult();
+    input2 = higherTensorValue;
   }
 
   return success();
 }
 
 namespace {
-template <typename OpTy> struct ConvertTosaOp : public OpRewritePattern<OpTy> {
+template <typename OpTy>
+struct ConvertTosaOp : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
@@ -158,14 +159,12 @@
     if (!outputType)
       return failure();
 
-    Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2)
+                             input1, input2)
             .failed())
       return failure();
 
-    rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
-                                      outInput2);
+    rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2);
 
     return success();
   }
@@ -188,14 +187,13 @@
     if (!outputType)
       return failure();
 
-    Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2)
+                             input1, input2)
             .failed())
       return failure();
 
-    rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
-                                             outInput1, outInput2, shift);
+    rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1,
+                                             input2, shift);
 
     return success();
   }
 
@@ -220,14 +218,63 @@
     if (!outputType)
       return failure();
 
-    Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2)
+                             input1, input2)
             .failed())
       return failure();
 
     rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
-        tosaBinaryOp, outputType, outInput1, outInput2, round);
+        tosaBinaryOp, outputType, input1, input2, round);
+
+    return success();
+  }
+};
+
+template <>
+struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
+                                PatternRewriter &rewriter) const override {
+
+    Value input1 = tosaOp.getPred();
+    Value input2 = tosaOp.getOnTrue();
+    Value input3 = tosaOp.getOnFalse();
+    Value output = tosaOp.getResult();
+
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+    if (!outputType)
+      return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
+
+    // Apply broadcasting to each pair of inputs separately, chaining the
+    // updated operands along so that the broadcasting happens all at once.
+    bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          input1, input2)
+                         .succeeded();
+
+    bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          input1, input3)
+                         .succeeded();
+
+    bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          input2, input3)
+                         .succeeded();
+
+    if (!reshaped1 && !reshaped2 && !reshaped3)
+      return rewriter.notifyMatchFailure(
+          tosaOp,
+          "cannot rewrite as the ranks of all operands are already aligned");
+
+    int32_t result1Rank = input1.getType().cast<RankedTensorType>().getRank();
+    int32_t result2Rank = input2.getType().cast<RankedTensorType>().getRank();
+    int32_t result3Rank = input3.getType().cast<RankedTensorType>().getRank();
+
+    if ((result1Rank != result2Rank) || (result2Rank != result3Rank))
+      return rewriter.notifyMatchFailure(
+          tosaOp, "not all ranks are aligned with each other");
+
+    rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1,
+                                                input2, input3);
 
     return success();
   }
@@ -263,6 +310,7 @@
     patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
     patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
     patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
+    patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
     patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
     (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
   }
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -195,3 +195,91 @@
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
   return %0 : tensor<17x16x15x14xi32>
 }
+
+// -----
+// CHECK-LABEL: broadcast_select_both_input
+func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x16x16xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 1>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x16x16xi1>, tensor<f32>, tensor<f32>) -> tensor<1x16x16xf32>
+  return %0 : tensor<1x16x16xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_one_input
+func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor<f32>) -> tensor<17x16x15x14xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK: %[[VAL_1:.*]] = "tosa.select"(%arg0, %arg1, %[[VAL_0]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor<f32>) -> tensor<17x16x15x14xf32>
+  return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_predicate
+func.func @test_broadcast_select_predicate(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK: %[[VAL_1:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_abc
+func.func @test_broadcast_select_abc(%arg0: tensor<i1>, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 32, 8>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_acb
+func.func @test_broadcast_select_acb(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 32, 8>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_bac
+func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 32, 8>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<f32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_bca
+func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<f32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 32, 8>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor<f32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_cab
+func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 32, 8>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<f32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_cba
+func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor<f32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 32, 8>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor<f32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}