diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -25,7 +25,7 @@ /// There are two potential ways implementing broadcast: /// a. https://www.tensorflow.org/xla/broadcasting#formal_definition /// b. https://numpy.org/doc/stable/user/basics.broadcasting.html -/// TBD: picking option (a) now. +/// This pass implements b (numpy style) now. /// In this pass, we insert RESHAPE operators to increase the rank of the /// lower rank operand as a first step in the broadcasting process. The TOSA @@ -33,75 +33,39 @@ /// are equal. // Examples: -// If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1]. -// TODO: If lower=[b], target=[a, b, c], [b] should but NOT YET reshaped into -// [1, b, 1]. -// If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c]. -// If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c]. -// If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1]. -// If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c]. -// If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1]. +// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c]. +// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c]. +// If lower=[a], higher=[a, a], [a] reshaped into [1, a]. // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. -static void computeReshapeOutput(ArrayRef<int64_t> higherRankShape, - ArrayRef<int64_t> lowerRankShape, - SmallVectorImpl<int64_t> &reshapeOutputShape) { +static LogicalResult +computeReshapeOutput(ArrayRef<int64_t> higherRankShape, + ArrayRef<int64_t> lowerRankShape, + SmallVectorImpl<int64_t> &reshapeOutputShape) { // Initialize new shapes with [1] * higherRank. 
int64_t higherRank = higherRankShape.size(); int64_t lowerRank = lowerRankShape.size(); reshapeOutputShape.assign(higherRank, 1); - int64_t higherLeftIndex = 0; - int64_t higherRightIndex = higherRank; - int64_t lowerLeftIndex = 0; - int64_t lowerRightIndex = lowerRank; - int64_t higherRankDim, lowerRankDim; - - if (lowerRightIndex != 0 && higherRightIndex != 0) { - // Matches lower rank shape from right dimension first, until not - // matching high rank shape or reaching dimension 0. - while (true) { - higherRankDim = higherRankShape[higherRightIndex - 1]; - lowerRankDim = lowerRankShape[lowerRightIndex - 1]; - if (higherRankDim != lowerRankDim) - break; - - reshapeOutputShape[higherRightIndex - 1] = higherRankDim; - - if (higherRightIndex > 0) - higherRightIndex--; - - if (lowerRightIndex > 0) - lowerRightIndex--; - - if (higherRightIndex == 0 || lowerRightIndex == 0) - break; - } - if (lowerRightIndex != 0 && higherRightIndex != 0) { - // Matches lower rank shape from left dimension, until not matching - // high rank shape or reaching right index. 
- while (true) { - higherRankDim = higherRankShape[higherLeftIndex]; - lowerRankDim = lowerRankShape[lowerLeftIndex]; - if (higherRankDim != lowerRankDim) - break; - - reshapeOutputShape[higherLeftIndex] = higherRankDim; - - if (higherLeftIndex < higherRightIndex) - higherLeftIndex++; - - if (lowerLeftIndex < lowerRightIndex) - lowerLeftIndex++; - - if (higherLeftIndex == higherRightIndex || - lowerLeftIndex == lowerRightIndex) - break; - } - } + int64_t higherRankDim; + int64_t lowerRankDim; + + for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0; + i--, j--) { + higherRankDim = higherRankShape[i]; + lowerRankDim = lowerRankShape[j]; + + if (lowerRankDim == 1 && higherRankDim > 1) + reshapeOutputShape[i] = 1; + else if ((lowerRankDim > 1 && higherRankDim == 1) || + (lowerRankDim == higherRankDim)) + reshapeOutputShape[i] = lowerRankDim; + else if (higherRankDim != lowerRankDim) + return failure(); } + return success(); } /// Common code to create the reshape op where necessary to make the rank of the @@ -143,8 +107,9 @@ SmallVector<int64_t> reshapeOutputShape; - computeReshapeOutput(outputType.getShape(), lowerRankShape, - reshapeOutputShape); + if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape) + .failed()) + return failure(); auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>(); auto reshapeOutputType = RankedTensorType::get( diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir --- a/mlir/test/Dialect/Tosa/broadcast.mlir +++ b/mlir/test/Dialect/Tosa/broadcast.mlir @@ -11,7 +11,8 @@ // ----- // CHECK-LABEL: broadcast1 func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1) %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32> return %0 : tensor<2x1xf32> } @@ -19,7 +20,8 @@ // 
----- // CHECK-LABEL: broadcast2 func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]]) %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32> return %0 : tensor<2x1xf32> } @@ -27,7 +29,8 @@ // ----- // CHECK-LABEL: broadcast3 func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]]) %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32> return %0 : tensor<2x1x1x1xf32> } @@ -35,7 +38,8 @@ // ----- // CHECK-LABEL: broadcast4 func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]]) %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32> return %0 : tensor<1x1x1x2xf32> } @@ -43,7 +47,8 @@ // ----- // CHECK-LABEL: broadcast5 func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]]) %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32> return %0 : tensor<1x1x2x1xf32> } @@ -51,7 +56,8 @@ // ----- // CHECK-LABEL: broadcast6 func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]]) 
%0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32> return %0 : tensor<17x16x15x14xf32> } @@ -59,7 +65,8 @@ // ----- // CHECK-LABEL: broadcast7 func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]]) %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32> return %0 : tensor<17x16x1x14xf32> } @@ -67,7 +74,8 @@ // ----- // CHECK-LABEL: broadcast8 func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]]) %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32> return %0 : tensor<17x16x15x14xf32> } @@ -75,7 +83,8 @@ // ----- // CHECK-LABEL: broadcast9 func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 15, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]]) %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32> return %0 : tensor<17x16x15x14xf32> } @@ -83,7 +92,8 @@ // ----- // CHECK-LABEL: broadcast10 func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 15, 14]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]]) %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32> return %0 : tensor<17x16x15x14xf32> } @@ -91,7 +101,8 @@ // ----- 
// CHECK-LABEL: broadcast13 func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1) %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> return %0 : tensor<17x16x15x14xf32> } @@ -99,7 +110,8 @@ // ----- // CHECK-LABEL: broadcast14 func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1) %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> return %0 : tensor<17x16x1x14xf32> } @@ -107,7 +119,8 @@ // ----- // CHECK-LABEL: broadcast15 func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1) %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> return %0 : tensor<17x16x15x14xf32> } @@ -115,7 +128,8 @@ // ----- // CHECK-LABEL: broadcast16 func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 15, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1) %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> return %0 : tensor<17x16x15x14xf32> } @@ -123,7 +137,8 @@ // ----- // CHECK-LABEL: broadcast17 func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { - // CHECK: reshape + // 
CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 15, 14]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1) %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> return %0 : tensor<17x16x15x14xf32> } @@ -131,24 +146,34 @@ // ----- // CHECK-LABEL: broadcast18 func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> { - // CHECK: add + // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %arg1) %0 = "tosa.add"(%arg0, %arg1) : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32> return %0 : tensor<14x15xf32> } // ----- // CHECK-LABEL: broadcast19 -func @broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) { - // CHECK: reshape - // CHECK: sub +func @test_broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) { + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 17]} + // CHECK: %[[VAR1:.*]] = "tosa.sub"(%arg0, %[[VAR0]]) %0 = "tosa.sub"(%arg0, %arg1) : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32> return %0 : tensor<64x64x17xf32> } +// ----- +// CHECK-LABEL: broadcast20 +func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>) -> (tensor<3x3x4x5xf32> ) { + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 4, 5]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%arg0, %[[VAR0]]) + %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x3x4x1xf32>, tensor<4x5xf32>) -> tensor<3x3x4x5xf32> + return %0 : tensor<3x3x4x5xf32> +} + // ----- // CHECK-LABEL: broadcast_mul func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 15, 14]} + // CHECK: %[[VAR1:.*]] = "tosa.mul"(%[[VAR0]], %arg1) %0 = "tosa.mul"(%arg0, %arg1) {shift = 1 : i32 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> 
tensor<17x16x15x14xi32> return %0 : tensor<17x16x15x14xi32> } @@ -156,7 +181,8 @@ // ----- // CHECK-LABEL: broadcast_arithmetic_right_shift func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> { - // CHECK: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 15, 14]} + // CHECK: %[[VAR1:.*]] = "tosa.arithmetic_right_shift"(%[[VAR0]], %arg1) %0 = "tosa.arithmetic_right_shift"(%arg0, %arg1) { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> return %0 : tensor<17x16x15x14xi32> } @@ -164,7 +190,8 @@ // ----- // CHECK-LABEL: broadcast_scalar func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> { - // CHECK-NEXT: reshape + // CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 1, 1, 1]} + // CHECK: %[[VAR1:.*]] = "tosa.add"(%[[VAR0]], %arg1) %0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> return %0 : tensor<17x16x15x14xi32> }