diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -486,14 +486,15 @@ MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { + auto elementType = IntegerType::get(context, /*width=*/1); + llvm::SmallVector outShape; if (resolveBroadcastShape(operands, outShape).failed()) { - inferredReturnShapes.push_back(ShapedTypeComponents()); + inferredReturnShapes.push_back(ShapedTypeComponents(elementType)); return success(); } - inferredReturnShapes.push_back( - ShapedTypeComponents(outShape, IntegerType::get(context, /*width=*/1))); + inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType)); return success(); } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -1224,3 +1224,13 @@ %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor, tensor) -> (tensor, tensor) return %output_real, %output_imag : tensor, tensor } + +// ----- + +// CHECK-LABEL: @test_unranked_equal +func.func @test_unranked_equal(%arg0 : tensor<*xf32>, %arg1 : tensor) -> () { + // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<*xf32>, tensor) -> tensor<*xi1> + %0 = "tosa.equal"(%arg0, %arg1) : (tensor<*xf32>, tensor) -> tensor<*xi1> + + return +}