diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1139,10 +1139,8 @@
 //===----------------------------------------------------------------------===//
 // Operator: equal
 //===----------------------------------------------------------------------===//
-def Tosa_EqualOp : Tosa_Op<"equal", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Commutative, NoSideEffect]> {
+def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
+                                     Commutative, NoSideEffect]> {
   let summary = "Returns the truth value of (x == y) element-wise.";
 
   let description = [{
@@ -1157,6 +1155,12 @@
   let results = (outs
     I1Tensor:$output
   );
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op; this
+    /// method is used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -339,6 +340,49 @@
   }
 }
 
+/// Broadcasts the shapes of all `operands` against each other, aligning them
+/// from the innermost dimension, and writes the result to `outShape`. Fails if
+/// any operand is unranked or if two dimensions differ and neither is 1. A
+/// dynamic dimension is only resolved against 1 or another dynamic dimension.
+static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
+                                           SmallVector<int64_t> &outShape) {
+  int64_t outRank = 0;
+  for (int i = 0, e = operands.size(); i != e; ++i) {
+    auto shape = operands.getShape(i);
+    if (!shape.hasRank()) {
+      // TODO(jennik): Update function to have better case handling for invalid
+      // operands and for unranked tensors.
+      return failure();
+    }
+    outRank = std::max(outRank, shape.getRank());
+  }
+
+  outShape.resize(outRank, 1);
+
+  for (int i = 0, e = operands.size(); i != e; ++i) {
+    auto shape = operands.getShape(i);
+    // Right-align this operand's dimensions against the output shape.
+    auto rankDiff = outShape.size() - shape.getRank();
+
+    for (size_t d = 0, rank = shape.getRank(); d < rank; ++d) {
+      auto dim1 = outShape[d + rankDiff];
+      auto dim2 = shape.getDimSize(d);
+      auto resolvedDim = dim1;
+
+      if (dim1 == 1) {
+        resolvedDim = dim2;
+      } else if (dim2 == 1) {
+        resolvedDim = dim1;
+      } else if (dim1 != dim2) {
+        return failure();
+      }
+      outShape[d + rankDiff] = resolvedDim;
+    }
+  }
+
+  return success();
+}
+
 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -421,6 +465,32 @@
   return success();
 }
 
+LogicalResult tosa::EqualOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // The result is always an i1 tensor, even when its shape is unknown.
+  Type i1Type = IntegerType::get(context, /*width=*/1);
+  llvm::SmallVector<int64_t> outShape;
+  if (resolveBroadcastShape(operands, outShape).failed()) {
+    // Fall back to an unranked i1 result. The element type must be kept:
+    // InferTensorType builds a tensor type from these components.
+    inferredReturnShapes.push_back(ShapedTypeComponents(i1Type));
+    return success();
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, i1Type));
+  return success();
+}
+
+bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  // Only shapes are compared; the result element type is already constrained
+  // to i1 by the op's I1Tensor result definition.
+  if (l.size() != r.size() || l.size() != 1)
+    return false;
+  return succeeded(verifyCompatibleShape(l[0], r[0]));
+}
+
 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -870,42 +940,6 @@
 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
 #undef REDUCE_SHAPE_INFER
 
-static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
-                                           SmallVector<int64_t> &outShape) {
-  int64_t outRank = 0;
-  for (int i = 0, e = operands.size(); i != e; ++i) {
-    auto shape = operands.getShape(i);
-    if (!shape.hasRank()) {
-      return failure();
-    }
-    outRank = std::max(outRank, shape.getRank());
-  }
-
-  outShape.resize(outRank, 1);
-
-  for (int i = 0, e = operands.size(); i != e; ++i) {
-    auto shape = operands.getShape(i);
-    auto rankDiff = outShape.size() - shape.getRank();
-
-    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
-      auto dim1 = outShape[i + rankDiff];
-      auto dim2 = shape.getDimSize(i);
-      auto resolvedDim = dim1;
-
-      if (dim1 == 1) {
-        resolvedDim = dim2;
-      } else if (dim2 == 1) {
-        resolvedDim = dim1;
-      } else if (dim1 != dim2) {
-        return failure();
-      }
-      outShape[i + rankDiff] = resolvedDim;
-    }
-  }
-
-  return success();
-}
-
 static LogicalResult NAryInferReturnTypes(
     const ValueShapeRange &operands,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -939,7 +973,6 @@
 NARY_SHAPE_INFER(tosa::ClampOp)
 NARY_SHAPE_INFER(tosa::ClzOp)
 NARY_SHAPE_INFER(tosa::DivOp)
-NARY_SHAPE_INFER(tosa::EqualOp)
 NARY_SHAPE_INFER(tosa::ExpOp)
 NARY_SHAPE_INFER(tosa::FloorOp)
 NARY_SHAPE_INFER(tosa::GreaterEqualOp)