diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1139,10 +1139,8 @@
 //===----------------------------------------------------------------------===//
 // Operator: equal
 //===----------------------------------------------------------------------===//
-def Tosa_EqualOp : Tosa_Op<"equal", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Commutative, NoSideEffect]> {
+def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
+                                     Commutative, NoSideEffect]> {
   let summary = "Returns the truth value of (x == y) element-wise.";
 
   let description = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -937,7 +937,30 @@
 NARY_SHAPE_INFER(tosa::ClampOp)
 NARY_SHAPE_INFER(tosa::ClzOp)
 NARY_SHAPE_INFER(tosa::DivOp)
-NARY_SHAPE_INFER(tosa::EqualOp)
+LogicalResult tosa::EqualOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // tosa.equal always produces a boolean tensor, so the result element type
+  // is i1 regardless of the operand element types.
+  Type resultElemTy = IntegerType::get(context, /*width=*/1);
+
+  // The result shape is the broadcast of *all* operand shapes (the op is
+  // declared ResultsBroadcastableShape). Copying only the first operand's
+  // shape is wrong whenever that operand is the one being broadcast, e.g.
+  // equal(tensor<1x3xf32>, tensor<2x3xf32>) must infer tensor<2x3xi1>,
+  // not tensor<1x3xi1>.
+  llvm::SmallVector<int64_t> outShape;
+  if (resolveBroadcastShape(operands, outShape).failed()) {
+    // No static broadcast shape could be resolved (presumably an unranked
+    // operand): report an unranked i1 result rather than guessing a shape.
+    inferredReturnShapes.emplace_back(resultElemTy);
+    return success();
+  }
+
+  inferredReturnShapes.emplace_back(outShape, resultElemTy);
+  return success();
+}
 NARY_SHAPE_INFER(tosa::ExpOp)
 NARY_SHAPE_INFER(tosa::FloorOp)
 NARY_SHAPE_INFER(tosa::GreaterEqualOp)