diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1139,10 +1139,8 @@
 //===----------------------------------------------------------------------===//
 // Operator: equal
 //===----------------------------------------------------------------------===//
-def Tosa_EqualOp : Tosa_Op<"equal", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Commutative, NoSideEffect]> {
+def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
+                                     Commutative, NoSideEffect]> {
   let summary = "Returns the truth value of (x == y) element-wise.";
 
   let description = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -337,6 +337,42 @@
   }
 }
 
+// Broadcasts all operand shapes into `outShape`, aligning dims from the
+// trailing end. Fails if an operand is unranked or if two aligned dims differ
+// and neither is 1.
+static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
+                                           SmallVector<int64_t> &outShape) {
+  int64_t outRank = 0;
+  for (int i = 0, e = operands.size(); i != e; ++i) {
+    auto shape = operands.getShape(i);
+    if (!shape.hasRank()) {
+      return failure();
+    }
+    outRank = std::max(outRank, shape.getRank());
+  }
+  // Size outShape to the largest operand rank, padding with 1s.
+  outShape.resize(outRank, 1);
+  // Merge each operand in; lower-rank shapes align with the trailing dims.
+  for (int i = 0, e = operands.size(); i != e; ++i) {
+    auto shape = operands.getShape(i);
+    auto rankDiff = outShape.size() - shape.getRank();
+    for (size_t j = 0, rank = shape.getRank(); j < rank; ++j) {
+      auto dim1 = outShape[j + rankDiff];
+      auto dim2 = shape.getDimSize(j);
+      auto resolvedDim = dim1;
+      if (dim1 == 1) {
+        resolvedDim = dim2;
+      } else if (dim2 == 1) {
+        resolvedDim = dim1;
+      } else if (dim1 != dim2) {
+        return failure();
+      }
+      outShape[j + rankDiff] = resolvedDim;
+    }
+  }
+  return success();
+}
+
 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -419,6 +455,25 @@
   return success();
 }
 
+LogicalResult tosa::EqualOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // tosa.equal always yields i1. Set the element type on every path:
+  // InferTensorType builds the result tensor type from these components and
+  // needs an element type to do so.
+  Type boolTy = IntegerType::get(context, /*width=*/1);
+  llvm::SmallVector<int64_t> outShape;
+  if (resolveBroadcastShape(operands, outShape).failed()) {
+    // Unranked or non-broadcastable operands: fall back to an unranked i1.
+    inferredReturnShapes.push_back(ShapedTypeComponents(boolTy));
+    return success();
+  }
+  // The result takes the broadcast of all operand shapes, not operand 0's.
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, boolTy));
+  return success();
+}
+
 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -868,42 +923,6 @@
 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
 #undef REDUCE_SHAPE_INFER
 
-static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
-                                           SmallVector<int64_t> &outShape) {
-  int64_t outRank = 0;
-  for (int i = 0, e = operands.size(); i != e; ++i) {
-    auto shape = operands.getShape(i);
-    if (!shape.hasRank()) {
-      return failure();
-    }
-    outRank = std::max(outRank, shape.getRank());
-  }
-
-  outShape.resize(outRank, 1);
-
-  for (int i = 0, e = operands.size(); i != e; ++i) {
-    auto shape = operands.getShape(i);
-    auto rankDiff = outShape.size() - shape.getRank();
-
-    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
-      auto dim1 = outShape[i + rankDiff];
-      auto dim2 = shape.getDimSize(i);
-      auto resolvedDim = dim1;
-
-      if (dim1 == 1) {
-        resolvedDim = dim2;
-      } else if (dim2 == 1) {
-        resolvedDim = dim1;
-      } else if (dim1 != dim2) {
-        return failure();
-      }
-      outShape[i + rankDiff] = resolvedDim;
-    }
-  }
-
-  return success();
-}
-
 static LogicalResult NAryInferReturnTypes(
     const ValueShapeRange &operands,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -937,7 +956,6 @@ EqualOp: drop generic NAry inference; it has its own definition
 NARY_SHAPE_INFER(tosa::ClampOp)
 NARY_SHAPE_INFER(tosa::ClzOp)
 NARY_SHAPE_INFER(tosa::DivOp)
-NARY_SHAPE_INFER(tosa::EqualOp)
 NARY_SHAPE_INFER(tosa::ExpOp)
 NARY_SHAPE_INFER(tosa::FloorOp)
 NARY_SHAPE_INFER(tosa::GreaterEqualOp)