diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -223,10 +223,13 @@
 }
 
 class Tosa_ElemWiseBinaryOp traits = []> :
-  Tosa_Op,
+  Tosa_Op {
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 #endif // TOSA_OP_BASE
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -968,6 +968,44 @@
   return success();
 }
 
+/// Returns true when the single `returned` result type is shape-compatible
+/// with the single `inferred` one: both non-shaped, or both shaped with either
+/// side unranked, or equal rank with every pair of static dims equal.
+/// A fully static `returned` shape is rejected when `inferred` is not static.
+static bool binaryIsCompatibleReturnTypes(TypeRange inferred,
+                                          TypeRange returned) {
+  if (inferred.size() != returned.size() || inferred.size() != 1)
+    return false;
+  auto inferredType = llvm::dyn_cast<ShapedType>(inferred[0]);
+  auto returnedType = llvm::dyn_cast<ShapedType>(returned[0]);
+
+  // Either both or neither type should be shaped. Checking `returnedType`
+  // here also guards the member calls on it below against a null type.
+  if (!inferredType)
+    return !returnedType;
+  if (!returnedType)
+    return false;
+
+  // An unranked type on either side places no constraint on the shape.
+  if (!inferredType.hasRank() || !returnedType.hasRank())
+    return true;
+
+  ArrayRef<int64_t> inferredShape = inferredType.getShape();
+  ArrayRef<int64_t> returnedShape = returnedType.getShape();
+  if (inferredShape.size() != returnedShape.size())
+    return false;
+  if (returnedType.hasStaticShape() && !inferredType.hasStaticShape())
+    return false;
+  // Dynamic dims match anything; two static dims must agree exactly.
+  for (auto [inferredDim, returnedDim] :
+       llvm::zip(inferredShape, returnedShape)) {
+    if (!ShapedType::isDynamic(inferredDim) &&
+        !ShapedType::isDynamic(returnedDim) && inferredDim != returnedDim)
+      return false;
+  }
+  return true;
+}
+
 #define COMPATIBLE_RETURN_TYPES(OP)                                            \
   bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
     if (l.size() != r.size() || l.size() != 1)                                 \
@@ -1005,10 +1043,11 @@
     const ValueShapeRange &operands,
     SmallVectorImpl &inferredReturnShapes) {
   llvm::SmallVector outShape;
+  Type inputType = operands.getType()[0].cast().getElementType();
   if (resolveBroadcastShape(operands, outShape).failed()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
+    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
   } else {
-    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+    inferredReturnShapes.push_back(ShapedTypeComponents(outShape, inputType));
   }
   return success();
 }
@@ -1022,6 +1061,28 @@
     return NAryInferReturnTypes(operands, inferredReturnShapes);               \
   }
 
+#define BINARY_COMPATIBLE_TYPES(OP)                                            \
+  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
+    return binaryIsCompatibleReturnTypes(l, r);                                \
+  }
+
+BINARY_COMPATIBLE_TYPES(tosa::AddOp)
+BINARY_COMPATIBLE_TYPES(tosa::ArithmeticRightShiftOp)
+BINARY_COMPATIBLE_TYPES(tosa::BitwiseAndOp)
+BINARY_COMPATIBLE_TYPES(tosa::BitwiseOrOp)
+BINARY_COMPATIBLE_TYPES(tosa::BitwiseXorOp)
+BINARY_COMPATIBLE_TYPES(tosa::DivOp)
+BINARY_COMPATIBLE_TYPES(tosa::LogicalAndOp)
+BINARY_COMPATIBLE_TYPES(tosa::LogicalLeftShiftOp)
+BINARY_COMPATIBLE_TYPES(tosa::LogicalOrOp)
+BINARY_COMPATIBLE_TYPES(tosa::LogicalRightShiftOp)
+BINARY_COMPATIBLE_TYPES(tosa::LogicalXorOp)
+BINARY_COMPATIBLE_TYPES(tosa::MaximumOp)
+BINARY_COMPATIBLE_TYPES(tosa::MinimumOp)
+BINARY_COMPATIBLE_TYPES(tosa::MulOp)
+BINARY_COMPATIBLE_TYPES(tosa::PowOp)
+BINARY_COMPATIBLE_TYPES(tosa::SubOp)
+#undef BINARY_COMPATIBLE_TYPES
 NARY_SHAPE_INFER(tosa::AbsOp)
 NARY_SHAPE_INFER(tosa::AddOp)
 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
@@ -1060,7 +1121,7 @@
 NARY_SHAPE_INFER(tosa::TanhOp)
 NARY_SHAPE_INFER(tosa::ErfOp)
 NARY_SHAPE_INFER(tosa::SigmoidOp)
-#undef PRED_SHAPE_INFER
+#undef NARY_SHAPE_INFER
 
 static LogicalResult poolingInferReturnTypes(
     const ValueShapeRange &operands, DictionaryAttr attributes,