diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1460,6 +1460,8 @@
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -805,6 +805,31 @@
   return {};
 }
 
+/// Folds tosa.reverse when the result is known without reversing anything:
+///  - a splat constant input folds to that same constant, and
+///  - reversing an axis of static length 1 folds to the input value.
+OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
+  auto operand = getInput();
+  auto operandTy = operand.getType().cast<ShapedType>();
+  // Held as a signed value so that an out-of-range attribute is rejected by
+  // the bounds check below instead of reaching getDimSize.
+  int64_t axis = getAxis();
+
+  // Every element of a splat is identical, so reversing it changes nothing.
+  // A general dense constant must not be folded this way: reversing it
+  // reorders its elements.
+  if (auto splatAttr = operands[0].dyn_cast_or_null<SplatElementsAttr>())
+    return splatAttr;
+
+  // If the dim-length is 1, tosa.reverse is a no-op. Bounds-check the axis
+  // first: ShapedType::getDimSize asserts on an out-of-range index.
+  if (operandTy.hasRank() && axis >= 0 && axis < operandTy.getRank() &&
+      operandTy.getDimSize(axis) == 1)
+    return operand;
+
+  return {};
+}
+
 OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
   auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
   auto outputTy = getType().dyn_cast<RankedTensorType>();
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -484,1 +484,35 @@
 }
+
+// -----
+
+// CHECK-LABEL: @reverse_splat
+func.func @reverse_splat() -> tensor<10xi32> {
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<42> : tensor<10xi32>}
+  %splat = "tosa.const"() {value = dense<42> : tensor<10xi32>} : () -> tensor<10xi32>
+  %reverse = "tosa.reverse"(%splat) { axis = 0 : i64 } : (tensor<10xi32>) -> tensor<10xi32>
+  // CHECK: return %[[SPLAT]]
+  return %reverse : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reverse_nofold_nonsplat
+func.func @reverse_nofold_nonsplat() -> tensor<2xi32> {
+  // CHECK: %[[CST:.+]] = "tosa.const"() {value = dense<[1, 2]> : tensor<2xi32>}
+  %cst = "tosa.const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[REVERSE:.+]] = "tosa.reverse"(%[[CST]]) {axis = 0 : i64}
+  %reverse = "tosa.reverse"(%cst) { axis = 0 : i64 } : (tensor<2xi32>) -> tensor<2xi32>
+  // CHECK: return %[[REVERSE]]
+  return %reverse : tensor<2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reverse_length_one
+func.func @reverse_length_one(%arg0 : tensor<10x1xi32>) -> (tensor<10x1xi32>, tensor<10x1xi32>) {
+  %nofold = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
+  %fold = "tosa.reverse"(%arg0) { axis = 1 : i64 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
+  // CHECK: %[[NOFOLD:.+]] = "tosa.reverse"(%arg0) {axis = 0 : i64}
+  // CHECK: return %[[NOFOLD]], %arg0
+  return %nofold, %fold : tensor<10x1xi32>, tensor<10x1xi32>
+}