diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1500,6 +1500,7 @@
   );
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1037,3 +1037,42 @@
 
   return {};
 }
+
+/// Folds operands that are themselves produced by a ConcatOp on the same axis
+/// into this op by splicing the producer's operands in place of its result.
+/// Updates this op's operands in place and returns its own result when at
+/// least one operand was spliced; returns an empty OpFoldResult otherwise.
+OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
+  // Keep track of the operands so we are able to construct a new concat
+  // later. Conservatively assume that we double the number of operands when
+  // folding.
+  SmallVector<Value> concatOperands;
+  concatOperands.reserve(2 * getNumOperands());
+
+  // Find all operands that are foldable concats.
+  bool foundFoldableConcat = false;
+  for (Value operand : getOperands()) {
+    concatOperands.emplace_back(operand);
+
+    // The defining op may be null, hence dyn_cast_or_null.
+    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
+    if (!producer)
+      continue;
+
+    // Not foldable if axes are not the same.
+    if (getAxis() != producer.getAxis())
+      continue;
+
+    // Replace the original operand with all incoming operands.
+    foundFoldableConcat = true;
+    concatOperands.pop_back();
+    llvm::append_range(concatOperands, producer->getOperands());
+  }
+
+  if (!foundFoldableConcat)
+    return {};
+
+  // Rewire this op onto the flattened operand list and return its own result.
+  getOperation()->setOperands(concatOperands);
+  return getResult();
+}
diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/fold_concats.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
+
+func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
+  %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  return %0 : tensor<1x2x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @single_concat(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
+// CHECK: %[[VAL_1:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK: return %[[VAL_1]] : tensor<1x2x7x7xf32>
+// CHECK: }
+
+// -----
+
+func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
+  %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = "tosa.concat"(%0, %0) {axis = 0} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+  return %1 : tensor<2x2x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @concat_different_axis(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
+// CHECK: %[[VAL_1:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_1]]) <{axis = 0 : i64}> : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<2x2x7x7xf32>
+// CHECK: }
+
+// -----
+
+func.func @fold_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+  %tmp = tensor.empty() : tensor<1x1x7x7xf32>
+  %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = "tosa.concat"(%tmp, %0, %tmp) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+  return %1 : tensor<1x4x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @fold_concats(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32>
+// CHECK: }
+
+// -----
+
+func.func @nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
+  %tmp = tensor.empty() : tensor<1x1x7x7xf32>
+  %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = "tosa.concat"(%tmp, %0, %tmp) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+  %2 = "tosa.concat"(%1, %1) {axis = 1} : (tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32>) -> tensor<1x8x7x7xf32>
+  return %2 : tensor<1x8x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @nested_fold(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
+// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x8x7x7xf32>
+// CHECK: }
+
+// -----
+
+func.func @wide_fold(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+  %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = "tosa.concat"(%arg1, %arg1) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %2 = "tosa.concat"(%0, %1) {axis = 1} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
+  return %2 : tensor<1x4x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @wide_fold(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]]) <{axis = 1 : i64}> : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32>
+// CHECK: }
+
+// -----
+
+func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
+  %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x2x8x8xf32>
+  %1 = "tosa.concat"(%arg1, %arg1) {axis = 2} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+  %2 = "tosa.concat"(%0, %1) {axis = 1} : (tensor<1x2x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+  return %2 : tensor<1x4x8x8xf32>
+}
+
+// CHECK-LABEL: func.func @partially_foldable(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x8x8xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_1]]) <{axis = 2 : i64}> : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]], %[[VAL_2]]) <{axis = 1 : i64}> : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32>
+// CHECK: }