diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -717,9 +717,21 @@
   auto inputTy = getInput1().getType().dyn_cast<RankedTensorType>();
   auto outputTy = getType().dyn_cast<RankedTensorType>();
 
-  if (!inputTy || !outputTy || inputTy != outputTy)
+  if (!inputTy || !outputTy)
     return {};
-  return getInput1();
+
+  // Identity reshape: forward the input value unchanged.
+  if (inputTy == outputTy)
+    return getInput1();
+
+  // operands[0] is null unless the input is a constant. Reshaping a splat
+  // constant changes only its shape, so fold to the same splat value with the
+  // result type; the folded constant needs a statically shaped result type.
+  auto operand = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  if (operand && outputTy.hasStaticShape() && operand.isSplat())
+    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
+
+  return {};
 }
 
 OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -398,6 +398,17 @@
 
 // -----
 
+// CHECK-LABEL: @reshape_splat
+func.func @reshape_splat() -> tensor<6x5x4xi32> {
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<42> : tensor<6x5x4xi32>}
+  %splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
+  %reshape = "tosa.reshape"(%splat) { new_shape = [6, 5, 4] } : (tensor<4x5x6xi32>) -> tensor<6x5x4xi32>
+  // CHECK: return %[[SPLAT]]
+  return %reshape : tensor<6x5x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @slice_splat
 func.func @slice_splat() -> tensor<1x1x1xi32> {
   // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}