diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -143,6 +143,98 @@
   return success();
 }
 
+// Reshape the lower-ranked operands of a ternary op (tosa.select) so that all
+// three operands have the same rank as the highest-ranked one, by prepending
+// size-1 dimensions (broadcast-compatible reshape). On success the (possibly
+// reshaped) operands are written to outInput1/2/3 in the original order.
+static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
+                                          Location loc,
+                                          RankedTensorType outputType,
+                                          Value input1, Value input2,
+                                          Value input3, Value &outInput1,
+                                          Value &outInput2, Value &outInput3) {
+  SmallVector<Value, 3> inputs;
+  inputs.push_back(input1);
+  inputs.push_back(input2);
+  inputs.push_back(input3);
+
+  // Reference wrappers so the reshaped values can be written back to the
+  // caller's out-parameters by the original operand index.
+  SmallVector<std::reference_wrapper<Value>, 3> outInputs;
+  outInputs.push_back(outInput1);
+  outInputs.push_back(outInput2);
+  outInputs.push_back(outInput3);
+
+  // Create a vector of pair <rank, operand index> for sorting.
+  SmallVector<std::pair<int64_t, int64_t>, 3> ranks;
+  for (int i = 0; i < 3; ++i) {
+    auto inputTy = inputs[i].getType().dyn_cast<RankedTensorType>();
+    if (!inputTy)
+      return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
+
+    ranks.push_back(std::make_pair(inputTy.getRank(), i));
+  }
+
+  // After sorting, the last element of vector is the highest-rank tensor.
+  std::sort(ranks.begin(), ranks.end());
+
+  // Initialization.
+  outInput1 = input1;
+  outInput2 = input2;
+  outInput3 = input3;
+  int higherTensorRank = ranks[2].first;
+  int higherTensorIndex = ranks[2].second;
+  Value higherTensorValue = inputs[higherTensorIndex];
+
+  bool reshaped = false;
+  for (int i = 0; i < 2; ++i) {
+    int lowerTensorRank = ranks[i].first;
+
+    if (lowerTensorRank != higherTensorRank) {
+      int lowerTensorIndex = ranks[i].second;
+      Value lowerTensorValue = inputs[lowerTensorIndex];
+
+      ArrayRef<int64_t> higherRankShape =
+          higherTensorValue.getType().cast<RankedTensorType>().getShape();
+      (void)higherRankShape;
+      ArrayRef<int64_t> lowerRankShape =
+          lowerTensorValue.getType().cast<RankedTensorType>().getShape();
+
+      SmallVector<int64_t, 4> reshapeOutputShape;
+      if (computeReshapeOutput(higherRankShape, lowerRankShape,
+                               reshapeOutputShape)
+              .failed())
+        return rewriter.notifyMatchFailure(
+            loc,
+            "fail to compute a reshaped output for a lower rank input tensor");
+
+      auto reshapeInputType =
+          lowerTensorValue.getType().cast<RankedTensorType>();
+      auto reshapeOutputType =
+          RankedTensorType::get(ArrayRef<int64_t>(reshapeOutputShape),
+                                reshapeInputType.getElementType());
+
+      // Verify the rank agrees with the output type if the output type is
+      // ranked.
+      if (outputType) {
+        if (outputType.getShape().size() != reshapeOutputShape.size() ||
+            outputType.getShape().size() != higherRankShape.size())
+          return rewriter.notifyMatchFailure(
+              loc, "the reshaped rank doesn't agrees with the output type");
+      }
+
+      auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
+          loc, reshapeOutputType, lowerTensorValue,
+          rewriter.getI64ArrayAttr(reshapeOutputShape));
+
+      outInputs[lowerTensorIndex].get() = reshapeLower.getResult();
+      reshaped = true;
+    }
+  }
+
+  if (!reshaped)
+    return rewriter.notifyMatchFailure(
+        loc, "cannot rewrite as the ranks of inputs are already correct");
+
+  outInputs[higherTensorIndex].get() = higherTensorValue;
+
+  return success();
+}
+
 namespace {
 template <typename OpTy>
 struct ConvertTosaOp : public OpRewritePattern<OpTy> {
@@ -233,6 +325,36 @@
     return success();
   }
 };
+
+// Broadcasting for the ternary tosa.select: the predicate and both value
+// operands may each need a rank-equalizing reshape.
+template <>
+struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
+                                PatternRewriter &rewriter) const override {
+
+    Value input1 = tosaOp.getOnTrue();
+    Value input2 = tosaOp.getOnFalse();
+    Value pred = tosaOp.getPred();
+    Value output = tosaOp.getResult();
+
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+    if (!outputType)
+      return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
+
+    Value outInput1, outInput2, outPred;
+    if (reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType, pred,
+                             input1, input2, outPred, outInput1, outInput2)
+            .failed())
+      return rewriter.notifyMatchFailure(
+          tosaOp, "fail to broadcast a tensor, check if the ranks are valid");
+
+    rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, outPred,
+                                                outInput1, outInput2);
+
+    return success();
+  }
+};
 } // namespace
 
 namespace {
@@ -265,6 +387,7 @@
     patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
     patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
     patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
+    patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
     (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
   }
 };
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -195,3 +195,91 @@
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
   return %0 : tensor<17x16x15x14xi32>
 }
+
+// -----
+// CHECK-LABEL: broadcast_select_both_input
+func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x16x16xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x16x16xi1>, tensor<f32>, tensor<f32>) -> tensor<1x16x16xf32>
+  return %0 : tensor<1x16x16xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_one_input
+func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor<f32>) -> tensor<17x16x15x14xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAL_1:.*]] = "tosa.select"(%arg0, %arg1, %[[VAL_0]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor<f32>) -> tensor<17x16x15x14xf32>
+  return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_predicate
+func.func @test_broadcast_select_predicate(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAL_1:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_abc
+func.func @test_broadcast_select_abc(%arg0: tensor<i1>, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 32, 8]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_acb
+func.func @test_broadcast_select_acb(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 32, 8]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_bac
+func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 32, 8]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<f32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_bca
+func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<f32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 32, 8]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor<f32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_cab
+func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 32, 8]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<f32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_cba
+func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor<f32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 32, 8]}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1, 1]}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor<f32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}