diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -491,11 +491,30 @@ auto inputTy = getInput().getType().dyn_cast(); auto outputTy = getType().dyn_cast(); - if (!inputTy || !outputTy || inputTy != outputTy) + if (!inputTy || !outputTy) return {}; - if (inputTy.hasStaticShape()) + + if (inputTy == outputTy && inputTy.hasStaticShape()) return getInput(); + if (!operands[0]) + return {}; + + auto operand = operands[0].cast(); + if (operand.isSplat() && outputTy.hasStaticShape()) { + return SplatElementsAttr::get(outputTy, operand.getSplatValue()); + } + + if (inputTy.hasStaticShape() && outputTy.hasStaticShape() && + outputTy.getNumElements() == 1) { + llvm::SmallVector indices; + for (auto val : getStart()) { + indices.push_back(val.cast().getInt()); + } + auto value = operand.getValues()[indices]; + return SplatElementsAttr::get(outputTy, value); + } + return {}; } diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -97,3 +97,26 @@ %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> return %0: tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> } + +// ----- + +// CHECK-LABEL: @slice_splat +func.func @slice_splat() -> tensor<1x1x1xi32> { + // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>} + %splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32> + %slice = "tosa.slice"(%splat) { size = [1, 1, 1], start = [1, 2, 3] } : (tensor<4x5x6xi32>) -> tensor<1x1x1xi32> + // CHECK: return %[[SLICE]] + return %slice : tensor<1x1x1xi32> +} + +// ----- + +// CHECK-LABEL: @slice_singleton +func.func @slice_singleton() -> tensor<1x1xi32> { + %splat = "tosa.const"() {value = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32> + // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<4> : tensor<1x1xi32>} + %slice = "tosa.slice"(%splat) { size = [1, 1], start = [1, 1] } : (tensor<3x3xi32>) -> tensor<1x1xi32> + // CHECK: return %[[SLICE]] + return %slice : tensor<1x1xi32> +} +