diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -803,8 +803,8 @@ // If we are comparing an integer value to itself it is always true. We can // not do this with float due to float values. - if (lhsTy.getElementType().isa() && resultTy.hasStaticShape() && - lhs == rhs) { + if (lhsTy.getElementType().isa() && resultTy && + resultTy.hasStaticShape() && lhs == rhs) { return DenseElementsAttr::get(resultTy, true); } diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir --- a/mlir/test/Dialect/Tosa/constant_folding.mlir +++ b/mlir/test/Dialect/Tosa/constant_folding.mlir @@ -6,3 +6,11 @@ %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> return %0 : tensor<4xi32> } + +// CHECK-LABEL: func @try_fold_equal_with_unranked_tensor +func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor) { + // CHECK: "tosa.equal" + // CHECK-NEXT: return + %0 = "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> + return +}